{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# From Unlabeled Data to a Deployed Machine Learning Model: A SageMaker Ground Truth Demonstration for Object Detection\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Run a Ground Truth labeling job (time: about 4h)](#Run-a-Ground-Truth-labeling-job)\n",
    "    1. [Prepare the data](#Prepare-the-data)\n",
    "    2. [Specify the categories](#Specify-the-categories)\n",
    "    3. [Create the instruction template](#Create-the-instruction-template)\n",
    "    4. [Create a private team to test your task [OPTIONAL]](#Create-a-private-team-to-test-your-task-[OPTIONAL])\n",
    "    5. [Define pre-built lambda functions for use in the labeling job](#Define-pre-built-lambda-functions-for-use-in-the-labeling-job)\n",
    "    6. [Submit the Ground Truth job request](#Submit-the-Ground-Truth-job-request)\n",
    "        1. [Verify your task using a private team [OPTIONAL]](#Verify-your-task-using-a-private-team-[OPTIONAL])\n",
    "    7. [Monitor job progress](#Monitor-job-progress)\n",
    "3. [Analyze Ground Truth labeling job results (time: about 20min)](#Analyze-Ground-Truth-labeling-job-results)\n",
    "    1. [Postprocess the output manifest](#Postprocess-the-output-manifest)\n",
    "    2. [Plot class histograms](#Plot-class-histograms)\n",
    "    3. [Plot annotated images](#Plot-annotated-images)\n",
    "        1. [Plot a small output sample](#Plot-a-small-output-sample)\n",
    "        2. [Plot the full results](#Plot-the-full-results)\n",
    "4. [Compare Ground Truth results to standard labels (time: about 5min)](#Compare-Ground-Truth-results-to-standard-labels)\n",
    "    1. [Compute accuracy](#Compute-accuracy)\n",
    "    2. [Plot correct and incorrect annotations](#Plot-correct-and-incorrect-annotations)\n",
    "5. [Train an object detector using Ground Truth labels (time: about 15min)](#Train-an-object-detector-using-Ground-Truth-labels)\n",
    "6. [Deploy the Model (time: about 20min)](#Deploy-the-Model)\n",
    "    1. [Create Model](#Create-Model)\n",
    "    2. [Batch Transform](#Batch-Transform)\n",
    "7. [Review](#Review)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This sample notebook takes you through an end-to-end workflow to demonstrate the functionality of SageMaker Ground Truth. We'll start with an unlabeled image dataset, acquire bounding boxes for objects in the images using SageMaker Ground Truth, analyze the results, train an object detector, host the resulting model, and, finally, use it to make predictions. Before you begin, we highly recommend that you start a Ground Truth labeling job through the AWS Console to familiarize yourself with the workflow. The AWS Console offers less flexibility than the API, but it is simpler to use.\n",
    "\n",
    "#### Cost and runtime\n",
    "You can run this demo in two modes:\n",
    "1. Set `RUN_FULL_AL_DEMO = True` in the next cell to label 1000 images. This should cost about \\$200 given the current [Ground Truth pricing scheme](https://aws.amazon.com/sagemaker/groundtruth/pricing/). To reduce the cost, we will use Ground Truth's auto-labeling feature. Auto-labeling uses computer vision to learn from human responses and automatically create bounding boxes for the easiest images at a lower cost. The total end-to-end runtime should be about 6h.\n",
    "2. Set `RUN_FULL_AL_DEMO = False` in the next cell to label only 100 images. This should cost about \\$26. **Since Ground Truth's auto-labeling feature only kicks in for datasets of 1000 images or more, this cheaper version of the demo will not use it. Some of the analysis plots might look awkward, but you should still see good results on the 100 human-annotated images.**\n",
    "\n",
    "#### Prerequisites\n",
    "To run this notebook, you can simply execute each cell in order. To understand what's happening, you'll need:\n",
    "* An S3 bucket you can write to -- please provide its name in the following cell. The bucket must be in the same region as this SageMaker Notebook instance. You can also change the `EXP_NAME` to any valid S3 prefix. All the files related to this experiment will be stored in that prefix of your bucket. \n",
    "* Familiarity with Python and [numpy](http://www.numpy.org/).\n",
    "* Basic familiarity with [AWS S3](https://docs.aws.amazon.com/s3/index.html).\n",
    "* Basic understanding of [Amazon SageMaker](https://aws.amazon.com/sagemaker/).\n",
    "* Basic familiarity with [AWS Command Line Interface (CLI)](https://aws.amazon.com/cli/) -- ideally, you should have it set up with credentials to access the AWS account you're running this notebook from.\n",
    "\n",
    "This notebook has only been tested on a SageMaker notebook instance, and the runtimes given are approximate. We used an `ml.m4.xlarge` instance in our tests. However, you can likely run it on a local machine by first executing the setup cells below on SageMaker and then copying the `role` string to your local copy of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "from collections import namedtuple\n",
    "from collections import defaultdict\n",
    "from collections import Counter\n",
    "from datetime import datetime\n",
    "import itertools\n",
    "import base64\n",
    "import glob\n",
    "import json\n",
    "import random\n",
    "import time\n",
    "import imageio\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import shutil\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import boto3\n",
    "import botocore\n",
    "import sagemaker\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "BUCKET = '<< YOUR S3 BUCKET NAME >>'\n",
    "EXP_NAME = 'ground-truth-od-full-demo' # Any valid S3 prefix.\n",
    "RUN_FULL_AL_DEMO = True # See 'Cost and Runtime' in the Markdown cell above!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the bucket is in the same region as this notebook.\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.session.Session().region_name\n",
    "s3 = boto3.client('s3')\n",
    "bucket_region = s3.head_bucket(Bucket=BUCKET)['ResponseMetadata']['HTTPHeaders']['x-amz-bucket-region']\n",
    "assert bucket_region == region, \"Your S3 bucket {} and this notebook need to be in the same region.\".format(BUCKET)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run a Ground Truth labeling job\n",
    "\n",
    "**This section should take about 4 hours to complete.**\n",
    "\n",
    "We will first run a labeling job. This involves several steps: collecting the images we want annotated, creating instructions, and writing a labeling job specification. In addition, we highly recommend that you run a (free) mock job using a private workforce before you submit any job to the public workforce. This notebook explains how to do that as an optional step. With a public workforce, this section should take about 4 hours, although the exact time depends on worker availability.\n",
    "\n",
    "### Prepare the data\n",
    "We will first download images and labels of a subset of the [Google Open Images Dataset](https://storage.googleapis.com/openimages/web/index.html). These labels were [carefully verified](https://storage.googleapis.com/openimages/web/factsfigures.html). Later, we will compare Ground Truth annotations to these labels. Our dataset will consist of images of various species of bird.\n",
    "\n",
    "If you chose `RUN_FULL_AL_DEMO = False`, we will use a subset of only 100 images from this dataset. This is a diverse dataset of interesting images, and it should be fun for the human annotators to work with. You are free to ask the annotators to annotate any images you wish, as long as the images do not contain adult content. If they do, you must adjust the labeling job request this notebook produces (in particular, the `FreeOfAdultContent` content classifier); please check the Ground Truth documentation.\n",
    "\n",
    "We will copy these images to your own `BUCKET` and create a corresponding *input manifest*. The input manifest is a formatted list of the S3 locations of the images we want Ground Truth to annotate. We will upload this manifest to our S3 `BUCKET`.\n",
    "\n",
    "#### Disclosure regarding the Open Images Dataset V4:\n",
    "Open Images Dataset V4 is created by Google Inc. We have not modified the images or the accompanying annotations. You can obtain the images and the annotations [here](https://storage.googleapis.com/openimages/web/download.html). The annotations are licensed by Google Inc. under the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license. The images are listed as having a [CC BY 2.0](https://creativecommons.org/licenses/by/2.0/) license. The following paper describes Open Images V4 in depth, from the data collection and annotation to detailed statistics about the data and evaluation of models trained on it.\n",
    "\n",
    "A. Kuznetsova, H. Rom, N. Alldrin, J. Uijlings, I. Krasin, J. Pont-Tuset, S. Kamali, S. Popov, M. Malloci, T. Duerig, and V. Ferrari.\n",
    "*The Open Images Dataset V4: Unified image classification, object detection, and visual relationship detection at scale.* arXiv:1811.00982, 2018. ([link to PDF](https://arxiv.org/abs/1811.00982))"
   ]
  },
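  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input manifest described above is a JSON Lines file: one JSON object per line, each with a `source-ref` key that points to an image in S3. The minimal sketch below builds such a manifest in memory and parses it back. The bucket name, prefix, and image IDs are placeholders rather than values used elsewhere in this notebook, and the helper names are our own.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def build_manifest(bucket, prefix, image_ids):\n",
    "    # One JSON object per line, each pointing at one image under prefix/images/.\n",
    "    lines = []\n",
    "    for img_id in image_ids:\n",
    "        uri = 's3://{}/{}/images/{}.jpg'.format(bucket, prefix, img_id)\n",
    "        lines.append(json.dumps({'source-ref': uri}))\n",
    "    return '\\n'.join(lines) + '\\n'\n",
    "\n",
    "def read_manifest(text):\n",
    "    # Parse every non-empty line back into a dict.\n",
    "    return [json.loads(line) for line in text.splitlines() if line.strip()]\n",
    "\n",
    "manifest = build_manifest('my-example-bucket', 'my-prefix', ['img001', 'img002'])\n",
    "records = read_manifest(manifest)\n",
    "print(records[0]['source-ref'])\n",
    "```\n",
    "\n",
    "Using `json.dumps` rather than string concatenation keeps each line valid JSON even if a path contains characters that need escaping."
   ]
  },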
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Download and process the Open Images annotations.\n",
    "!wget https://storage.googleapis.com/openimages/2018_04/test/test-annotations-bbox.csv\n",
    "!wget https://storage.googleapis.com/openimages/2018_04/bbox_labels_600_hierarchy.json\n",
    "    \n",
    "with open('bbox_labels_600_hierarchy.json', 'r') as f:\n",
    "    hierarchy = json.load(f)\n",
    "    \n",
    "CLASS_NAME = 'Bird'\n",
    "CLASS_ID = '/m/015p6'\n",
    "\n",
    "# Find all the subclasses of the desired image class (e.g. 'swans' and 'pigeons' etc if CLASS_NAME=='Bird').\n",
    "good_subclasses = set()\n",
    "def get_all_subclasses(hierarchy, good_subtree=False):\n",
    "    if hierarchy['LabelName'] == CLASS_ID:\n",
    "        good_subtree = True\n",
    "    if good_subtree:\n",
    "        good_subclasses.add(hierarchy['LabelName'])\n",
    "    if 'Subcategory' in hierarchy:            \n",
    "        for subcat in hierarchy['Subcategory']:\n",
    "            get_all_subclasses(subcat, good_subtree=good_subtree)\n",
    "    return good_subclasses\n",
    "good_subclasses = get_all_subclasses(hierarchy)\n",
    "\n",
    "# Find an appropriate number of images with at least one bounding box in the desired category\n",
    "if RUN_FULL_AL_DEMO:\n",
    "    n_ims = 1000\n",
    "else:\n",
    "    n_ims = 100\n",
    "    \n",
    "fids2bbs = defaultdict(list)\n",
    "# Skip images with risky content.\n",
    "skip_these_images = ['251d4c429f6f9c39', \n",
    "                    '065ad49f98157c8d']\n",
    "\n",
    "with open('test-annotations-bbox.csv', 'r') as f:\n",
    "    for line in f.readlines()[1:]:\n",
    "        line = line.strip().split(',')\n",
    "        img_id, _, cls_id, conf, xmin, xmax, ymin, ymax, *_ = line\n",
    "        if img_id in skip_these_images:\n",
    "            continue\n",
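    "        if cls_id in good_subclasses:\n",
    "            # Stop at n_ims images, but only when a new image ID appears, so the last\n",
    "            # image keeps all of its boxes (assuming rows are grouped by image ID).\n",
    "            if img_id not in fids2bbs and len(fids2bbs) == n_ims:\n",
    "                break\n",
    "            fids2bbs[img_id].append([CLASS_NAME, xmin, xmax, ymin, ymax])\n",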
    "\n",
    "# Copy the images to our local bucket.\n",
    "s3 = boto3.client('s3')\n",
    "for img_id_id, img_id in enumerate(fids2bbs.keys()):\n",
    "    if img_id_id % 100 == 0:\n",
    "        print('Copying image {} / {}'.format(img_id_id, n_ims))\n",
    "    copy_source = {\n",
    "        'Bucket': 'open-images-dataset',\n",
    "        'Key': 'test/{}.jpg'.format(img_id)\n",
    "    }\n",
    "    s3.copy(copy_source, BUCKET, '{}/images/{}.jpg'.format(EXP_NAME, img_id))\n",
    "print('Done!')\n",
    "\n",
    "# Create and upload the input manifest.\n",
    "manifest_name = 'input.manifest'\n",
    "with open(manifest_name, 'w') as f:\n",
    "    for img_id in fids2bbs.keys():\n",
    "        img_path = 's3://{}/{}/images/{}.jpg'.format(BUCKET, EXP_NAME, img_id)\n",
    "        f.write(json.dumps({'source-ref': img_path}) + '\\n')\n",
    "s3.upload_file(manifest_name, BUCKET, EXP_NAME + '/' + manifest_name)"
   ]
  },
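  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_all_subclasses` walks the Open Images label hierarchy recursively and collects every label at or below `CLASS_ID`, accumulating the result into the module-level `good_subclasses` set. A side-effect-free variant of the same traversal is sketched below on a tiny made-up hierarchy (the label names are placeholders, not real Open Images IDs). Because it returns a fresh set on each call, calling it for two different classes cannot merge their results.\n",
    "\n",
    "```python\n",
    "def collect_subclasses(node, target_id, inside=False):\n",
    "    # Once the target label is reached, every descendant is part of its subtree.\n",
    "    inside = inside or node['LabelName'] == target_id\n",
    "    found = {node['LabelName']} if inside else set()\n",
    "    for child in node.get('Subcategory', []):\n",
    "        found |= collect_subclasses(child, target_id, inside)\n",
    "    return found\n",
    "\n",
    "# Made-up hierarchy using the same keys as bbox_labels_600_hierarchy.json.\n",
    "toy_hierarchy = {\n",
    "    'LabelName': 'root',\n",
    "    'Subcategory': [\n",
    "        {'LabelName': 'bird',\n",
    "         'Subcategory': [{'LabelName': 'swan'}, {'LabelName': 'pigeon'}]},\n",
    "        {'LabelName': 'car'},\n",
    "    ],\n",
    "}\n",
    "print(sorted(collect_subclasses(toy_hierarchy, 'bird')))\n",
    "```"
   ]
  },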
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After running the cell above, you should be able to go to `s3://BUCKET/EXP_NAME/images` in the [S3 console](https://console.aws.amazon.com/s3/) and see 1000 images (or 100 if you have set `RUN_FULL_AL_DEMO = False`). We recommend you inspect these images! You can download them to a local machine using the AWS CLI."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the categories\n",
    "\n",
    "To run an object detection labeling job, you must decide on a set of classes the annotators can choose from. At the time of writing, Ground Truth supports annotating only one object detection (OD) class at a time. In our case, the singleton class list is simply `[\"Bird\"]`. To work with Ground Truth, this list needs to be converted to a JSON file and uploaded to the S3 `BUCKET`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CLASS_LIST = [CLASS_NAME]\n",
    "print(\"Label space is {}\".format(CLASS_LIST))\n",
    "\n",
    "json_body = {\n",
    "    'labels': [{'label': label} for label in CLASS_LIST]\n",
    "}\n",
    "with open('class_labels.json', 'w') as f:\n",
    "    json.dump(json_body, f)\n",
    "    \n",
    "s3.upload_file('class_labels.json', BUCKET, EXP_NAME + '/class_labels.json')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should now see `class_labels.json` in `s3://BUCKET/EXP_NAME/`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the instruction template\n",
    "\n",
    "Part or all of your images will be annotated by human annotators. It is **essential** to provide good instructions. Good instructions are:\n",
    "1. Concise. We recommend limiting verbal/textual instruction to two sentences and focusing on clear visuals.\n",
    "2. Visual. In the case of object detection, we recommend providing several labeled examples with different numbers of boxes.\n",
    "\n",
    "When used through the AWS Console, Ground Truth helps you create the instructions using a visual wizard. When using the API, you need to create an HTML template for your instructions. Below, we prepare a very simple but effective template and upload it to your S3 bucket.\n",
    "\n",
    "NOTE: If your template links to images hosted elsewhere, they need to be publicly accessible. You can enable public access to files in your S3 bucket through the S3 Console, as described in the [S3 Documentation](https://docs.aws.amazon.com/AmazonS3/latest/user-guide/set-object-permissions.html). This notebook sidesteps the issue by embedding its example image directly in the template as a base64-encoded data URI.\n",
    "\n",
    "#### Testing your instructions\n",
    "**It is very easy to create broken instructions.** This might cause your labeling job to fail. However, it might also cause your job to complete with meaningless results if, for example, the annotators have no idea what to do or the instructions are misleading. At the moment the only way to test the instructions is to run your job in a private workforce. This is a way to run a mock labeling job for free. We describe how in [Verify your task using a private team [OPTIONAL]](#Verify-your-task-using-a-private-team-[OPTIONAL]).\n",
    "\n",
    "It is helpful to show examples of correctly labeled images in the instructions. The following code block draws two such examples from our dataset, saves the figure locally as `instructions.png`, and base64-encodes it so that it can be embedded in the template."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot sample images.\n",
    "def plot_bbs(ax, bbs, img):\n",
    "    '''Add bounding boxes to images.'''\n",
    "    ax.imshow(img)\n",
    "    imh, imw, _ = img.shape\n",
    "    for bb in bbs:\n",
    "        xmin, xmax, ymin, ymax = bb\n",
    "        xmin *= imw\n",
    "        xmax *= imw\n",
    "        ymin *= imh\n",
    "        ymax *= imh\n",
    "        rec = plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, fill=None, lw=4, edgecolor='blue')\n",
    "        ax.add_patch(rec)\n",
    "        \n",
    "plt.figure(facecolor='white', dpi=100, figsize=(3, 7))\n",
    "plt.suptitle('Please draw a box\\n around each {}\\n like the examples below.\\n Thank you!'.format(CLASS_NAME), fontsize=15)\n",
    "for fid_id, (fid, bbs) in enumerate([list(fids2bbs.items())[idx] for idx in [1, 3]]):\n",
    "    !aws s3 cp s3://open-images-dataset/test/{fid}.jpg .\n",
    "    img = imageio.imread(fid + '.jpg')\n",
    "    bbs = [[float(a) for a in annot[1:]] for annot in bbs]\n",
    "    ax = plt.subplot(2, 1, fid_id+1)\n",
    "    plot_bbs(ax, bbs, img)\n",
    "    plt.axis('off')\n",
    "    \n",
    "plt.savefig('instructions.png', dpi=60)\n",
    "with open('instructions.png', 'rb') as instructions:\n",
    "    instructions_uri = base64.b64encode(instructions.read()).decode('utf-8').replace('\\n', '')"
   ]
  },
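  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open Images stores each box as normalized `XMin`, `XMax`, `YMin`, `YMax` coordinates in [0, 1], which is why `plot_bbs` multiplies them by the image width and height. The hypothetical helper below isolates that conversion and returns the `(left, top, width, height)` pixel values that `plt.Rectangle` takes (its `xy` corner followed by width and height). The sample box and image size are made up.\n",
    "\n",
    "```python\n",
    "def to_pixel_box(bb, img_height, img_width):\n",
    "    # bb holds normalized (xmin, xmax, ymin, ymax), the order plot_bbs unpacks.\n",
    "    xmin, xmax, ymin, ymax = bb\n",
    "    left = xmin * img_width\n",
    "    top = ymin * img_height  # imshow puts the origin at the top-left corner\n",
    "    width = (xmax - xmin) * img_width\n",
    "    height = (ymax - ymin) * img_height\n",
    "    return left, top, width, height\n",
    "\n",
    "print(to_pixel_box((0.25, 0.75, 0.125, 0.625), img_height=200, img_width=400))\n",
    "```"
   ]
  },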
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML, display\n",
    "\n",
    "def make_template(test_template=False, save_fname='instructions.template'):\n",
    "    template = r\"\"\"<script src=\"https://assets.crowd.aws/crowd-html-elements.js\"></script>\n",
    "    <crowd-form>\n",
    "      <crowd-bounding-box\n",
    "        name=\"boundingBox\"\n",
    "        src=\"{{{{ task.input.taskObject | grant_read_access }}}}\"\n",
    "        header=\"Dear Annotator, please draw a tight box around each {class_name} you see (if there are more than 8 birds, draw boxes around at least 8). Thank you!\"\n",
    "        labels=\"{labels_str}\"\n",
    "      >\n",
    "        <full-instructions header=\"Please annotate each {class_name}.\">\n",
    "\n",
    "    <ol>\n",
    "        <li><strong>Inspect</strong> the image</li>\n",
    "        <li><strong>Determine</strong> whether the specified label is visible in the picture.</li>\n",
    "        <li><strong>Outline</strong> each instance of the specified label in the image using the provided “Box” tool.</li>\n",
    "    </ol>\n",
    "    <ul>\n",
    "        <li>Boxes should fit tight around each object</li>\n",
    "        <li>Do not include parts of the object that are occluded or that cannot be seen, even if you think you can interpolate the whole shape.</li>\n",
    "        <li>Avoid including shadows.</li>\n",
    "        <li>If the target is off screen, draw the box up to the edge of the image.</li>\n",
    "    </ul>\n",
    "\n",
    "        </full-instructions>\n",
    "        <short-instructions>\n",
    "        <img src=\"data:image/png;base64,{instructions_uri}\" style=\"max-width:100%\">\n",
    "        </short-instructions>\n",
    "      </crowd-bounding-box>\n",
    "    </crowd-form>\n",
    "    \"\"\".format(class_name=CLASS_NAME,\n",
    "               instructions_uri=instructions_uri,\n",
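    "               # For the local preview, mimic what the Liquid to_json | escape filters produce.\n",
    "               labels_str=json.dumps(CLASS_LIST).replace('\"', '&quot;') if test_template else '{{ task.input.labels | to_json | escape }}')\n",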
    "    with open(save_fname, 'w') as f:\n",
    "        f.write(template)\n",
    "\n",
    "        \n",
    "make_template(test_template=True, save_fname='instructions.html')\n",
    "make_template(test_template=False, save_fname='instructions.template')\n",
    "s3.upload_file('instructions.template', BUCKET, EXP_NAME + '/instructions.template')"
   ]
  },
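  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The template is built with Python's `str.format`, which treats `{` and `}` as placeholder delimiters. To emit a literal Liquid expression such as `{{ task.input.taskObject }}`, each brace must be doubled in the template source, which is why it contains `{{{{ ... }}}}`. Values passed as arguments to `format` (like the `labels_str` Liquid expression) are inserted verbatim and need no doubling. The short check below, with a made-up one-line template, shows both cases.\n",
    "\n",
    "```python\n",
    "# Doubled braces come out of str.format as single braces; {class_name} is substituted.\n",
    "template = 'src={{{{ task.input.taskObject | grant_read_access }}}} header={class_name}'\n",
    "rendered = template.format(class_name='Bird')\n",
    "print(rendered)\n",
    "\n",
    "# Values passed as format arguments are inserted verbatim, braces included.\n",
    "labels_liquid = '{{ task.input.labels | to_json | escape }}'\n",
    "rendered_labels = 'labels={labels_str}'.format(labels_str=labels_liquid)\n",
    "print(rendered_labels)\n",
    "```"
   ]
  },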
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should now be able to find your template in `s3://BUCKET/EXP_NAME/instructions.template`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a private team to test your task [OPTIONAL]\n",
    "This step requires you to use the AWS Console. However, we **highly recommend** that you follow it, especially when creating your own task with a custom dataset, label set, and template.\n",
    "\n",
    "We will create a `private workteam` and add only one user (you) to it. Then, we will modify the Ground Truth API job request to send the task to that workforce. You will then be able to see your annotation job exactly as the public annotators would see it. You could even annotate the whole dataset yourself! \n",
    "\n",
    "To create a private team:\n",
    "1. Go to `AWS Console > Amazon SageMaker > Labeling workforces`\n",
    "2. Click \"Private\" and then \"Create private team\". \n",
    "3. Enter the desired name for your private workteam.\n",
    "4. Select \"Create a new Amazon Cognito user group\" and click \"Create private team.\"\n",
    "5. The AWS Console should now return to `AWS Console > Amazon SageMaker > Labeling workforces`.\n",
    "6. Click on \"Invite new workers\" in the \"Workers\" tab.\n",
    "7. Enter your own email address in the \"Email addresses\" section and click \"Invite new workers.\"\n",
    "8. Click on your newly created team under the \"Private teams\" tab.\n",
    "9. Select the \"Workers\" tab and click \"Add workers to team.\"\n",
    "10. Select your email and click \"Add workers to team.\"\n",
    "11. The AWS Console should again return to `AWS Console > Amazon SageMaker > Labeling workforces`. Your newly created team should be visible under \"Private teams\". Next to it you will see an `ARN`, a long string that looks like `arn:aws:sagemaker:region-name:123456789012:workteam/private-crowd/team-name`. Copy this ARN into the cell below.\n",
    "12. You should get an email from `no-reply@verificationemail.com` that contains your workforce username and password. \n",
    "13. In `AWS Console > Amazon SageMaker > Labeling workforces > Private`, click on the URL under `Labeling portal sign-in URL`. Use the email/password combination from the previous step to log in (you will be asked to create a new, non-default password).\n",
    "\n",
    "That's it! This is your private worker's interface. When we create a verification task in [Verify your task using a private team](#Verify-your-task-using-a-private-team-[OPTIONAL]) below, your task should appear in this window. You can invite your colleagues to participate in the labeling job by clicking the \"Invite new workers\" button.\n",
    "\n",
    "The [SageMaker Ground Truth documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-workforce-management-private.html) has more details on the management of private workteams. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "private_workteam_arn = '<< your private workteam ARN here >>'"
   ]
  },
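  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A workteam ARN has the general form `arn:aws:sagemaker:REGION:ACCOUNT_ID:workteam/WORKFORCE_TYPE/TEAM_NAME`. Before pasting it into the cell above, it can help to sanity-check it. The helper below is a hypothetical convenience function, not part of the SageMaker SDK, and the private ARN in the example is made up.\n",
    "\n",
    "```python\n",
    "def parse_workteam_arn(arn):\n",
    "    # Split into the six colon-separated ARN fields; the last field may contain slashes.\n",
    "    parts = arn.split(':', 5)\n",
    "    if len(parts) != 6 or parts[0] != 'arn' or parts[2] != 'sagemaker':\n",
    "        raise ValueError('Not a SageMaker ARN: {}'.format(arn))\n",
    "    resource = parts[5].split('/')\n",
    "    if len(resource) != 3 or resource[0] != 'workteam':\n",
    "        raise ValueError('Not a workteam ARN: {}'.format(arn))\n",
    "    return {'region': parts[3], 'account': parts[4],\n",
    "            'workforce': resource[1], 'team': resource[2]}\n",
    "\n",
    "info = parse_workteam_arn('arn:aws:sagemaker:us-west-2:123456789012:workteam/private-crowd/my-team')\n",
    "print(info['workforce'], info['team'])\n",
    "```"
   ]
  },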
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define pre-built lambda functions for use in the labeling job \n",
    "\n",
    "Before we submit the request, we need to define the ARNs for four key components of the labeling job: 1) the workteam, 2) the annotation consolidation Lambda function, 3) the pre-labeling task Lambda function, and 4) the machine learning algorithm to perform auto-annotation. These ARNs are strings that contain region names and AWS service account numbers, so we define a mapping below that will enable you to run this notebook in any of our supported regions. \n",
    "\n",
    "See the official documentation for the available ARNs:\n",
    "* [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/API_HumanTaskConfig.html#SageMaker-Type-HumanTaskConfig-PreHumanTaskLambdaArn) for available pre-human ARNs for other workflows.\n",
    "* [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/API_AnnotationConsolidationConfig.html#SageMaker-Type-AnnotationConsolidationConfig-AnnotationConsolidationLambdaArn) for available annotation consolidation ARNs for other workflows.\n",
    "* [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/API_LabelingJobAlgorithmsConfig.html#SageMaker-Type-LabelingJobAlgorithmsConfig-LabelingJobAlgorithmSpecificationArn) for available auto-labeling ARNs for other workflows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify ARNs for resources needed to run an object detection job.\n",
    "ac_arn_map = {'us-west-2': '081040173940',\n",
    "              'us-east-1': '432418664414',\n",
    "              'us-east-2': '266458841044',\n",
    "              'eu-west-1': '568282634449',\n",
    "              'ap-northeast-1': '477331159723'}\n",
    "\n",
    "prehuman_arn = 'arn:aws:lambda:{}:{}:function:PRE-BoundingBox'.format(region, ac_arn_map[region])\n",
    "acs_arn = 'arn:aws:lambda:{}:{}:function:ACS-BoundingBox'.format(region, ac_arn_map[region]) \n",
    "labeling_algorithm_specification_arn = 'arn:aws:sagemaker:{}:027400017018:labeling-job-algorithm-specification/object-detection'.format(region)\n",
    "workteam_arn = 'arn:aws:sagemaker:{}:394669845002:workteam/public-crowd/default'.format(region)"
   ]
  },
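  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ac_arn_map[region]` raises a bare `KeyError` if the notebook runs in a region that is not in the map. A slightly more defensive variant, sketched below, fails with an explicit message instead. The function name `build_bbox_arns` is hypothetical, and the account map is passed in so the sketch runs standalone; its two sample entries are copied from the cell above.\n",
    "\n",
    "```python\n",
    "def build_bbox_arns(region, account_map):\n",
    "    # Look up the AWS-owned account that hosts the bounding-box Lambdas in this region.\n",
    "    if region not in account_map:\n",
    "        raise ValueError('Region {} is not in the ARN map; supported: {}'.format(\n",
    "            region, sorted(account_map)))\n",
    "    account = account_map[region]\n",
    "    prehuman = 'arn:aws:lambda:{}:{}:function:PRE-BoundingBox'.format(region, account)\n",
    "    acs = 'arn:aws:lambda:{}:{}:function:ACS-BoundingBox'.format(region, account)\n",
    "    return prehuman, acs\n",
    "\n",
    "sample_map = {'us-west-2': '081040173940', 'us-east-1': '432418664414'}\n",
    "prehuman, acs = build_bbox_arns('us-east-1', sample_map)\n",
    "print(acs)\n",
    "```"
   ]
  },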
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Submit the Ground Truth job request\n",
    "To start a Ground Truth job through the API, you submit a request. The request contains the full configuration of the annotation task and lets you adjust fine details of the job that are fixed to default values when you use the AWS Console. The parameters that make up the request are described in more detail in the [SageMaker Ground Truth documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateLabelingJob.html).\n",
    "\n",
    "After you submit the request, you should be able to see the job in your AWS Console, at `Amazon SageMaker > Labeling Jobs`.\n",
    "You can track the progress of the job there. This job will take several hours to complete. For larger jobs (say, 100,000 images), the speed and cost benefits of auto-labeling should be greater.\n",
    "\n",
    "### Verify your task using a private team [OPTIONAL]\n",
    "If you chose to follow the steps in [Create a private team](#Create-a-private-team-to-test-your-task-[OPTIONAL]), you can first verify that your task runs as expected. To do this:\n",
    "1. Set `VERIFY_USING_PRIVATE_WORKFORCE` to `True` in the cell below.\n",
    "2. Run the cell below. It defines the task and submits it to the private workforce (you).\n",
    "3. After a few minutes, you should be able to see your task in your private workforce interface (see [Create a private team](#Create-a-private-team-to-test-your-task-[OPTIONAL])).\n",
    "Please verify that the task appears as you want it to appear.\n",
    "4. If everything is in order, change `VERIFY_USING_PRIVATE_WORKFORCE` to `False` and rerun the cell below to start the real annotation task!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VERIFY_USING_PRIVATE_WORKFORCE = False\n",
    "USE_AUTO_LABELING = True\n",
    "\n",
    "task_description = 'Dear Annotator, please draw a box around each {}. Thank you!'.format(CLASS_NAME)\n",
    "task_keywords = ['image', 'object', 'detection']\n",
    "task_title = 'Please draw a box around each {}.'.format(CLASS_NAME)\n",
    "job_name = 'ground-truth-od-demo-' + str(int(time.time()))\n",
    "\n",
    "human_task_config = {\n",
    "      \"AnnotationConsolidationConfig\": {\n",
    "        \"AnnotationConsolidationLambdaArn\": acs_arn,\n",
    "      },\n",
    "      \"PreHumanTaskLambdaArn\": prehuman_arn,\n",
    "      \"MaxConcurrentTaskCount\": 200, # 200 images will be sent at a time to the workteam.\n",
    "      \"NumberOfHumanWorkersPerDataObject\": 5, # We will obtain and consolidate 5 human annotations for each image.\n",
    "      \"TaskAvailabilityLifetimeInSeconds\": 21600, # Your workteam has 6 hours to complete all pending tasks.\n",
    "      \"TaskDescription\": task_description,\n",
    "      \"TaskKeywords\": task_keywords,\n",
    "      \"TaskTimeLimitInSeconds\": 300, # Each image must be labeled within 5 minutes.\n",
    "      \"TaskTitle\": task_title,\n",
    "      \"UiConfig\": {\n",
    "        \"UiTemplateS3Uri\": 's3://{}/{}/instructions.template'.format(BUCKET, EXP_NAME),\n",
    "      }\n",
    "    }\n",
    "\n",
    "if not VERIFY_USING_PRIVATE_WORKFORCE:\n",
    "    human_task_config[\"PublicWorkforceTaskPrice\"] = {\n",
    "        \"AmountInUsd\": {\n",
    "           \"Dollars\": 0,\n",
    "           \"Cents\": 3,\n",
    "           \"TenthFractionsOfACent\": 6,\n",
    "        }\n",
    "    } \n",
    "    human_task_config[\"WorkteamArn\"] = workteam_arn\n",
    "else:\n",
    "    human_task_config[\"WorkteamArn\"] = private_workteam_arn\n",
    "\n",
    "ground_truth_request = {\n",
    "        \"InputConfig\" : {\n",
    "          \"DataSource\": {\n",
    "            \"S3DataSource\": {\n",
    "              \"ManifestS3Uri\": 's3://{}/{}/{}'.format(BUCKET, EXP_NAME, manifest_name),\n",
    "            }\n",
    "          },\n",
    "          \"DataAttributes\": {\n",
    "            \"ContentClassifiers\": [\n",
    "              \"FreeOfPersonallyIdentifiableInformation\",\n",
    "              \"FreeOfAdultContent\"\n",
    "            ]\n",
    "          },  \n",
    "        },\n",
    "        \"OutputConfig\" : {\n",
    "          \"S3OutputPath\": 's3://{}/{}/output/'.format(BUCKET, EXP_NAME),\n",
    "        },\n",
    "        \"HumanTaskConfig\" : human_task_config,\n",
    "        \"LabelingJobName\": job_name,\n",
    "        \"RoleArn\": role, \n",
    "        \"LabelAttributeName\": \"category\",\n",
    "        \"LabelCategoryConfigS3Uri\": 's3://{}/{}/class_labels.json'.format(BUCKET, EXP_NAME),\n",
    "    }\n",
    "\n",
    "if USE_AUTO_LABELING and RUN_FULL_AL_DEMO:\n",
    "    ground_truth_request[\"LabelingJobAlgorithmsConfig\"] = {\n",
    "        \"LabelingJobAlgorithmSpecificationArn\": labeling_algorithm_specification_arn\n",
    "    }\n",
    "    \n",
    "sagemaker_client = boto3.client('sagemaker')\n",
    "sagemaker_client.create_labeling_job(**ground_truth_request)"
   ]
  },
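  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`PublicWorkforceTaskPrice` splits the per-worker payment into `Dollars`, `Cents`, and `TenthFractionsOfACent`, so the setting above (0, 3, 6) means 0.036 USD per worker per image. The hypothetical helper below converts such a dict back into dollars and multiplies by `NumberOfHumanWorkersPerDataObject`. It covers only the worker payment, not the separate per-label Ground Truth fee.\n",
    "\n",
    "```python\n",
    "from decimal import Decimal\n",
    "\n",
    "def worker_cost_per_image(price, workers_per_object):\n",
    "    # Combine the three price fields into an exact Decimal amount in USD.\n",
    "    amount = price['AmountInUsd']\n",
    "    per_worker = (Decimal(amount['Dollars'])\n",
    "                  + Decimal(amount['Cents']) / 100\n",
    "                  + Decimal(amount['TenthFractionsOfACent']) / 1000)\n",
    "    return per_worker * workers_per_object\n",
    "\n",
    "price = {'AmountInUsd': {'Dollars': 0, 'Cents': 3, 'TenthFractionsOfACent': 6}}\n",
    "print(worker_cost_per_image(price, 5))\n",
    "```\n",
    "\n",
    "`Decimal` avoids the small rounding errors that binary floats would introduce when adding cents and tenths of a cent."
   ]
  },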
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monitor job progress\n",
    "A Ground Truth job can take a few hours to complete (if your dataset is larger than 10000 images, it can take much longer than that!). One way to monitor the job's progress is through the AWS Console. In this notebook, we will use Ground Truth output files and CloudWatch logs to monitor progress.\n",
    "\n",
    "You can re-evaluate the next cell repeatedly. It sends a `describe_labeling_job` request which should tell you whether the job is completed or not. If it is, then 'LabelingJobStatus' will be 'Completed'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client('sagemaker')\n",
    "sagemaker_client.describe_labeling_job(LabelingJobName=job_name)['LabelingJobStatus']"
   ]
  },
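  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of re-running the cell above by hand, you can poll until the job reaches a terminal state. The sketch below takes the status lookup as a function argument so it can be tried without AWS; with boto3 you could pass something like `lambda: sagemaker_client.describe_labeling_job(LabelingJobName=job_name)['LabelingJobStatus']`. The terminal states it checks (`Completed`, `Failed`, `Stopped`) come from the documented `LabelingJobStatus` values; the helper name and polling interval are our own choices.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}\n",
    "\n",
    "def wait_for_status(get_status, poll_seconds=60, max_polls=None):\n",
    "    # Call get_status() until it returns a terminal state or max_polls is reached.\n",
    "    polls = 0\n",
    "    while True:\n",
    "        status = get_status()\n",
    "        polls += 1\n",
    "        print('Poll {}: {}'.format(polls, status))\n",
    "        if status in TERMINAL_STATES:\n",
    "            return status\n",
    "        if max_polls is not None and polls >= max_polls:\n",
    "            return status\n",
    "        time.sleep(poll_seconds)\n",
    "\n",
    "# Offline demo: a fake status sequence stands in for a real labeling job.\n",
    "fake_states = iter(['InProgress', 'InProgress', 'Completed'])\n",
    "final = wait_for_status(lambda: next(fake_states), poll_seconds=0)\n",
    "```"
   ]
  },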
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell extracts detailed information on how your job is doing. You can re-evaluate it at any time. It should give you:\n",
    "* The number of human and machine-annotated images across the iterations of your labeling job.\n",
    "* The training curves of any neural network training jobs launched by Ground Truth **(only if you are running with `RUN_FULL_AL_DEMO=True`)**.\n",
    "* The cost of the human- and machine-annotated labels.\n",
    "\n",
    "To understand the pricing, study [this document](https://aws.amazon.com/sagemaker/groundtruth/pricing/) carefully. In our case, each human label costs `$0.08 + 5 * $0.036 = $0.26` and each auto-label costs `$0.08`. If you set `RUN_FULL_AL_DEMO=True`, there is also the added cost of using SageMaker instances for neural net training and inference during auto-labeling. However, this should be insignificant compared to the other costs.\n",
    "\n",
    "If `RUN_FULL_AL_DEMO==True`, then the job will proceed in multiple iterations. \n",
    "* Iteration 1: Ground Truth will send out 10 images as 'probes' for human annotation. If these are successfully annotated, proceed to Iteration 2.\n",
    "* Iteration 2: Send out a batch of `MaxConcurrentTaskCount - 10` (in our case, 190) images for human annotation to obtain an active learning training batch.\n",
    "* Iteration 3: Send out another batch of 200 images for human annotation to obtain an active learning validation set.\n",
    "* Iteration 4a: Train a neural net to do auto-labeling. Auto-label as many data points as possible. \n",
    "* Iteration 4b: If there is any data leftover, send out at most 200 images for human annotation.\n",
    "* Repeat Iteration 4a and 4b until all data is annotated.\n",
    "\n",
    "If `RUN_FULL_AL_DEMO==False`, only Iterations 1 and 2 will happen."
   ]
  },
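  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-image prices quoted above can be reproduced with a few lines of arithmetic. This is a sketch that assumes the same figures as this notebook (a `$0.08` base price per image, plus `$0.036` per worker for each of 5 workers on human-labeled images); check the pricing page for current prices. `estimate_labeling_cost` is a hypothetical helper, and it ignores the instance cost of auto-labeling training and inference.\n",
    "\n",
    "```python\n",
    "BASE_PRICE = 0.08        # Price per labeled image (assumed, as quoted above).\n",
    "WORKER_PRICE = 0.036     # Price per worker per image (assumed, as quoted above).\n",
    "WORKERS_PER_OBJECT = 5   # Number of workers per human-labeled image.\n",
    "\n",
    "\n",
    "def estimate_labeling_cost(n_human, n_auto):\n",
    "    '''Return (human_cost, auto_cost) in dollars for the given label counts.'''\n",
    "    human_price = BASE_PRICE + WORKERS_PER_OBJECT * WORKER_PRICE  # 0.26 per image\n",
    "    auto_price = BASE_PRICE                                       # 0.08 per image\n",
    "    return round(n_human * human_price, 2), round(n_auto * auto_price, 2)\n",
    "```\n",
    "\n",
    "For example, `estimate_labeling_cost(400, 600)` gives `(104.0, 48.0)`: 400 human-labeled images at 0.26 each and 600 auto-labeled images at 0.08 each."
   ]
  },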
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HUMAN_PRICE = 0.26\n",
    "AUTO_PRICE = 0.08\n",
    "\n",
    "try:\n",
    "    os.makedirs('od_output_data/', exist_ok=False)\n",
    "except FileExistsError:\n",
    "    shutil.rmtree('od_output_data/')\n",
    "    \n",
    "S3_OUTPUT = boto3.client('sagemaker').describe_labeling_job(LabelingJobName=job_name)[\n",
    "    'OutputConfig']['S3OutputPath'] + job_name\n",
    "\n",
    "# Count number of human annotations in each class each iteration.\n",
    "!aws s3 cp {S3_OUTPUT + '/annotations/consolidated-annotation/consolidation-response'} od_output_data/consolidation-response --recursive --quiet\n",
    "consolidated_nboxes = defaultdict(int)\n",
    "consolidated_nims = defaultdict(int)\n",
    "consolidation_times = {}\n",
    "consolidated_cost_times = []\n",
    "obj_ids = set()\n",
    "\n",
    "for consolidated_fname in glob.glob('od_output_data/consolidation-response/**', recursive=True):\n",
    "    if consolidated_fname.endswith('json'):\n",
    "        iter_id = int(consolidated_fname.split('/')[-2][-1])\n",
    "        # Store the time of the most recent consolidation event as iteration time.\n",
    "        iter_time = datetime.strptime(consolidated_fname.split('/')[-1], '%Y-%m-%d_%H:%M:%S.json')\n",
    "        if iter_id in consolidation_times:\n",
    "            consolidation_times[iter_id] = max(consolidation_times[iter_id], iter_time)\n",
    "        else:\n",
    "            consolidation_times[iter_id] = iter_time\n",
    "        consolidated_cost_times.append(iter_time)\n",
    "                                      \n",
    "        with open(consolidated_fname, 'r') as f:\n",
    "            consolidated_data = json.load(f)\n",
    "        for consolidation in consolidated_data:\n",
    "            obj_id = consolidation['datasetObjectId']\n",
    "            n_boxes = len(consolidation['consolidatedAnnotation']['content'][\n",
    "                'category']['annotations'])\n",
    "            if obj_id not in obj_ids:\n",
    "                obj_ids.add(obj_id)\n",
    "                consolidated_nims[iter_id] += 1            \n",
    "                consolidated_nboxes[iter_id] += n_boxes\n",
    "            \n",
    "total_human_labels = sum(consolidated_nims.values())\n",
    "            \n",
    "# Count the number of machine iterations in each class each iteration.\n",
    "!aws s3 cp {S3_OUTPUT + '/activelearning'} od_output_data/activelearning --recursive --quiet\n",
    "auto_nboxes = defaultdict(int)\n",
    "auto_nims = defaultdict(int)\n",
    "auto_times = {}\n",
    "auto_cost_times = []\n",
    "\n",
    "for auto_fname in glob.glob('od_output_data/activelearning/**', recursive=True):\n",
    "    if auto_fname.endswith('auto_annotator_output.txt'):\n",
    "        iter_id = int(auto_fname.split('/')[-3])\n",
    "        with open(auto_fname, 'r') as f:\n",
    "            annots = [' '.join(l.split()[1:]) for l in f.readlines()]\n",
    "        auto_nims[iter_id] += len(annots)\n",
    "        for annot in annots:\n",
    "            annot = json.loads(annot)\n",
    "            time_str = annot['category-metadata']['creation-date']\n",
    "            auto_time = datetime.strptime(time_str, '%Y-%m-%dT%H:%M:%S.%f')\n",
    "            n_boxes = len(annot['category']['annotations'])\n",
    "            auto_nboxes[iter_id] += n_boxes\n",
    "            if iter_id in auto_times:\n",
    "                auto_times[iter_id] = max(auto_times[iter_id], auto_time)\n",
    "            else:\n",
    "                auto_times[iter_id] = auto_time\n",
    "            auto_cost_times.append(auto_time)\n",
    "                \n",
    "total_auto_labels = sum(auto_nims.values())\n",
    "n_iters = max(len(auto_times), len(consolidation_times))\n",
    "\n",
    "# Get plots for auto-annotation neural-net training.\n",
    "def get_training_job_data(training_job_name):\n",
    "    logclient = boto3.client('logs')\n",
    "    log_group_name = '/aws/sagemaker/TrainingJobs'\n",
    "    log_stream_name = logclient.describe_log_streams(logGroupName=log_group_name,\n",
    "        logStreamNamePrefix=training_job_name)['logStreams'][0]['logStreamName']\n",
    "    train_log = logclient.get_log_events(\n",
    "        logGroupName=log_group_name,\n",
    "        logStreamName=log_stream_name,\n",
    "        startFromHead=True\n",
    "    )\n",
    "    events = train_log['events']\n",
    "    next_token = train_log['nextForwardToken']\n",
    "    while True:\n",
    "        train_log = logclient.get_log_events(\n",
    "            logGroupName=log_group_name,\n",
    "            logStreamName=log_stream_name,\n",
    "            startFromHead=True,\n",
    "            nextToken=next_token\n",
    "        )\n",
    "        if train_log['nextForwardToken'] == next_token:\n",
    "            break\n",
    "        events = events + train_log['events']\n",
    "\n",
    "    mAPs = []\n",
    "    for event in events:\n",
    "        msg = event['message']\n",
    "        if 'Final configuration' in msg:\n",
    "            num_samples = int(msg.split('num_training_samples\\': u\\'')[1].split('\\'')[0])\n",
    "        elif 'validation mAP <score>=(' in msg:\n",
    "            mAPs.append(float(msg.split('validation mAP <score>=(')[1][:-1]))\n",
    "\n",
    "    return num_samples, mAPs\n",
    "\n",
    "training_data = !aws s3 ls {S3_OUTPUT + '/training/'} --recursive\n",
    "training_sizes = []\n",
    "training_mAPs = []\n",
    "training_iters = []\n",
    "for line in training_data:\n",
    "    if line.split('/')[-1] == 'model.tar.gz':\n",
    "        training_job_name = line.split('/')[-3]\n",
    "        n_samples, mAPs = get_training_job_data(training_job_name)\n",
    "        training_sizes.append(n_samples)\n",
    "        training_mAPs.append(mAPs)\n",
    "        training_iters.append(int(line.split('/')[-5]))\n",
    "        \n",
    "plt.figure(facecolor='white', figsize=(14, 5), dpi=100)\n",
    "ax = plt.subplot(131)\n",
    "total_human = 0\n",
    "total_auto = 0\n",
    "for iter_id in range(1, n_iters + 1):\n",
    "    cost_human = consolidated_nims[iter_id] * HUMAN_PRICE\n",
    "    cost_auto = auto_nims[iter_id] * AUTO_PRICE\n",
    "    total_human += cost_human\n",
    "    total_auto += cost_auto\n",
    "    \n",
    "    plt.bar(iter_id, cost_human, width=.8, color='C0',\n",
    "            label='human' if iter_id==1 else None)\n",
    "    plt.bar(iter_id, cost_auto, bottom=cost_human,\n",
    "            width=.8, color='C1', label='auto' if iter_id==1 else None)\n",
    "plt.title('Total annotation costs:\\n\\${:.2f} human, \\${:.2f} auto'.format(\n",
    "    total_human, total_auto))\n",
    "plt.xlabel('Iter')\n",
    "plt.ylabel('Cost in dollars')\n",
    "plt.legend()\n",
    "\n",
    "plt.subplot(132)\n",
    "plt.title('Total annotation counts:\\nHuman: {} ims, {} boxes\\nMachine: {} ims, {} boxes'.format(\n",
    "    sum(consolidated_nims.values()), sum(consolidated_nboxes.values()), sum(auto_nims.values()), sum(auto_nboxes.values())))\n",
    "for iter_id in consolidated_nims.keys():\n",
    "    plt.bar(iter_id, auto_nims[iter_id], color='C1', width=.4, label='ims, auto' if iter_id==1 else None)\n",
    "    plt.bar(iter_id, consolidated_nims[iter_id],\n",
    "            bottom=auto_nims[iter_id], color='C0', width=.4, label='ims, human' if iter_id==1 else None)\n",
    "    plt.bar(iter_id + .4, auto_nboxes[iter_id], color='C1', alpha=.4, width=.4, label='boxes, auto' if iter_id==1 else None)\n",
    "    plt.bar(iter_id + .4, consolidated_nboxes[iter_id],\n",
    "            bottom=auto_nboxes[iter_id], color='C0', width=.4, alpha=.4, label='boxes, human' if iter_id==1 else None)\n",
    "\n",
    "tick_labels_boxes = ['Iter {}, boxes'.format(iter_id + 1) for iter_id in range(n_iters)]\n",
    "tick_labels_images = ['Iter {}, images'.format(iter_id + 1) for iter_id in range(n_iters)]\n",
    "tick_locations_images = np.arange(n_iters) + 1\n",
    "tick_locations_boxes = tick_locations_images + .4\n",
    "tick_labels = np.concatenate([[tick_labels_boxes[idx], tick_labels_images[idx]] for idx in range(n_iters)])\n",
    "tick_locations = np.concatenate([[tick_locations_boxes[idx], tick_locations_images[idx]] for idx in range(n_iters)])\n",
    "plt.xticks(tick_locations, tick_labels, rotation=90)\n",
    "plt.legend()\n",
    "plt.ylabel('Count')\n",
    "\n",
    "if len(training_sizes) > 0:\n",
    "    plt.subplot(133)\n",
    "    plt.title('Active learning training curves')\n",
    "    plt.grid(True)\n",
    "\n",
    "    cmap = plt.get_cmap('coolwarm')\n",
    "    n_all = len(training_sizes)\n",
    "    for iter_id_id, (iter_id, size, mAPs) in enumerate(zip(training_iters, training_sizes, training_mAPs)):\n",
    "        plt.plot(mAPs, label='Iter {}, auto'.format(iter_id + 1), color=cmap(iter_id_id / max(1, (n_all-1))))\n",
    "        plt.legend()\n",
    "\n",
    "    plt.xlabel('Training epoch')\n",
    "    plt.ylabel('Validation mAP')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Ground Truth labeling job results\n",
    "**This section should take about 20 minutes to complete.**\n",
    "\n",
    "Once the job has finished, we can analyze the results. Evaluate the following cell and verify the output is `'Completed'` before continuing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client.describe_labeling_job(LabelingJobName=job_name)['LabelingJobStatus']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plots in the [Monitor job progress](#Monitor-job-progress) section form part of the analysis. In this section, we will gain additional insights into the results, which are contained in the output manifest. You can find the location of the output manifest under `AWS Console > SageMaker > Labeling Jobs > [name of your job]`. We will obtain it programmatically in the cell below.\n",
    "\n",
    "## Postprocess the output manifest\n",
    "Now that the job is complete, we will download the output manifest manfiest and postprocess it to create a list of `output_images` with the results. Each entry in the list will be a `BoxedImage` object that contains information about the image and the bounding boxes created by the labeling jobs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the output manifest's annotations.\n",
    "OUTPUT_MANIFEST = 's3://{}/{}/output/{}/manifests/output/output.manifest'.format(BUCKET, EXP_NAME, job_name)\n",
    "\n",
    "!aws s3 cp {OUTPUT_MANIFEST} 'output.manifest'\n",
    "\n",
    "with open('output.manifest', 'r') as f:\n",
    "    output = [json.loads(line.strip()) for line in f.readlines()]\n",
    "    \n",
    "# Retrieve the worker annotations.\n",
    "!aws s3 cp {S3_OUTPUT + '/annotations/worker-response'} od_output_data/worker-response --recursive --quiet\n",
    "\n",
    "# Find the worker files.\n",
    "worker_file_names = glob.glob(\n",
    "    'od_output_data/worker-response/**/*.json', recursive=True)"
   ]
  },
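  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each line of the output manifest is a standalone JSON object. The sketch below parses one illustrative line using the same fields the next cell reads: `source-ref`, `category` (with `image_size` and `annotations`), and `category-metadata` (with per-box `objects` confidences and `human-annotated`). The sample line and the helper `boxes_from_manifest_line` are made up for illustration; a real manifest's metadata may contain additional keys.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# An illustrative (not real) output-manifest line for a job whose LabelAttributeName is 'category'.\n",
    "SAMPLE_LINE = json.dumps({\n",
    "    'source-ref': 's3://my-bucket/images/0001.jpg',\n",
    "    'category': {\n",
    "        'image_size': [{'width': 640, 'height': 480, 'depth': 3}],\n",
    "        'annotations': [\n",
    "            {'class_id': 0, 'left': 10, 'top': 20, 'width': 100, 'height': 50},\n",
    "        ],\n",
    "    },\n",
    "    'category-metadata': {\n",
    "        'objects': [{'confidence': 0.87}],\n",
    "        'human-annotated': 'yes',\n",
    "    },\n",
    "})\n",
    "\n",
    "\n",
    "def boxes_from_manifest_line(line, label='category'):\n",
    "    '''Return (uri, human, boxes), each box as (class_id, x0, y0, x1, y1, confidence).'''\n",
    "    datum = json.loads(line)\n",
    "    meta = datum[label + '-metadata']\n",
    "    human = meta['human-annotated'] == 'yes'\n",
    "    boxes = []\n",
    "    # Box annotations and confidences are paired by position in their lists.\n",
    "    for annot, obj in zip(datum[label]['annotations'], meta['objects']):\n",
    "        x0, y0 = annot['left'], annot['top']\n",
    "        boxes.append((annot['class_id'], x0, y0,\n",
    "                      x0 + annot['width'], y0 + annot['height'],\n",
    "                      obj['confidence']))\n",
    "    return datum['source-ref'], human, boxes\n",
    "```\n",
    "\n",
    "Calling `boxes_from_manifest_line(SAMPLE_LINE)` returns `('s3://my-bucket/images/0001.jpg', True, [(0, 10, 20, 110, 70, 0.87)])`. Converting `left`/`top`/`width`/`height` to corner coordinates is a convenient form for IoU computations."
   ]
  },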
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ground_truth_od import BoundingBox, WorkerBoundingBox, \\\n",
    "    GroundTruthBox, BoxedImage\n",
    "\n",
    "# Create data arrays.\n",
    "confidences = np.zeros(len(output))\n",
    "\n",
    "# Find the job name the manifest corresponds to.\n",
    "keys = list(output[0].keys())\n",
    "metakey = keys[np.where([('-metadata' in k) for k in keys])[0][0]]\n",
    "jobname = metakey[:-9]\n",
    "output_images = []\n",
    "consolidated_boxes = []\n",
    "\n",
    "# Extract the data.\n",
    "for datum_id, datum in enumerate(output):\n",
    "    image_size = datum['category']['image_size'][0]\n",
    "    box_annotations = datum['category']['annotations']\n",
    "    uri = datum['source-ref']\n",
    "    box_confidences = datum[metakey]['objects']\n",
    "    human = int(datum[metakey]['human-annotated'] == 'yes')\n",
    "\n",
    "    # Make image object.\n",
    "    image = BoxedImage(id=datum_id, size=image_size,\n",
    "                       uri=uri)\n",
    "\n",
    "    # Create bounding boxes for image.\n",
    "    boxes = []\n",
    "    for i, annotation in enumerate(box_annotations):\n",
    "        box = BoundingBox(image_id=datum_id, boxdata=annotation)\n",
    "        box.confidence = box_confidences[i]['confidence']\n",
    "        box.image = image\n",
    "        box.human = human\n",
    "        boxes.append(box)\n",
    "        consolidated_boxes.append(box)\n",
    "    image.consolidated_boxes = boxes\n",
    "\n",
    "    # Store if the image is human labeled.\n",
    "    image.human = human\n",
    "\n",
    "    # Retrieve ground truth boxes for the image.\n",
    "    oid_boxes_data = fids2bbs[image.oid_id]\n",
    "    gt_boxes = []\n",
    "    for data in oid_boxes_data:\n",
    "        gt_box = GroundTruthBox(image_id=datum_id, oiddata=data,\n",
    "                                image=image)\n",
    "        gt_boxes.append(gt_box)\n",
    "    image.gt_boxes = gt_boxes\n",
    "\n",
    "    output_images.append(image)\n",
    "\n",
    "# Iterate through the json files, creating bounding box objects.\n",
    "for wfn in worker_file_names:\n",
    "    image_id = int(wfn.split('/')[-2])\n",
    "    image = output_images[image_id]\n",
    "    with open(wfn, \"r\") as worker_file:\n",
    "        annotation = json.load(worker_file)\n",
    "        answers = annotation['answers']\n",
    "        for answer in answers:\n",
    "            wid = answer['workerId']\n",
    "            wboxes_data = \\\n",
    "                answer['answerContent']['boundingBox']['boundingBoxes']\n",
    "            for boxdata in (wboxes_data or []):\n",
    "                box = WorkerBoundingBox(image_id=image_id,\n",
    "                                        worker_id=wid,\n",
    "                                        boxdata=boxdata)\n",
    "                box.image = image\n",
    "                image.worker_boxes.append(box)\n",
    "\n",
    "# Get the human- and auto-labeled images.\n",
    "human_labeled = [img for img in output_images if img.human]\n",
    "auto_labeled = [img for img in output_images if not img.human]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot annotated images\n",
    "In any data science task, it is crucial to plot and inspect the results to check they make sense. In order to do this, we will \n",
    "1. Download the input images that Ground Truth annotated.\n",
    "2. Separate images annotated by humans from those annoted via the auto-labeling mechanism.\n",
    "3. Plot images in the human/auto-annotated classes.\n",
    "\n",
    "We will download the input images to a `LOCAL_IMAGE_DIR` you can choose in the next cell. Note that if this directory already contains images with the same filenames as your Ground Truth input images, we will not re-download the images.\n",
    "\n",
    "If your dataset is large and you do not wish to download and plot **all** the images, simply set `DATASET_SIZE` to a small number. We will pick a random subset of your data for plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOCAL_IMG_DIR = '<< choose a local directory name to download the images to >>' # Replace with the name of a local directory to store images.\n",
    "assert LOCAL_IMG_DIR != '<< choose a local directory name to download the images to >>', 'Please provide a local directory name'\n",
    "DATASET_SIZE = len(output_images) # Change this to a reasonable number if your dataset is larger than 10K images.\n",
    "\n",
    "image_subset = np.random.choice(output_images, DATASET_SIZE, replace=False)\n",
    "\n",
    "for img in image_subset:\n",
    "    target_fname = os.path.join(\n",
    "        LOCAL_IMG_DIR, img.uri.split('/')[-1])\n",
    "    if not os.path.isfile(target_fname):\n",
    "        !aws s3 cp {img.uri} {target_fname}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot a small output sample to understand the labeling mechanism\n",
    "The following cell will create two figures. The first plots `N_SHOW` images as annotated by humans. The first column shows the original bounding boxes produced by the human labelers working on Amazon Mechanical Turk. The second column shows the result of combining these boxes to produce a consolidated label, which is the final output of Ground Truth for the human-labeled images. Finally, the third column shows the \"true\" bounding boxes according to the Open Images Dataset for reference.\n",
    "\n",
    "The second plots `N_SHOW` images as annotated by the auto-labeling mechanism. In this case, there is no consolidation phase, so only the auto-labeled image and the \"true\" label are displayed.\n",
    "\n",
    "By default, `N_SHOW = 5`, but feel free to change this to any small number."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SHOW = 5\n",
    "\n",
    "# Find human and auto-labeled images in the subset.\n",
    "human_labeled_subset = [img for img in image_subset if img.human]\n",
    "auto_labeled_subset = [img for img in image_subset if not img.human]\n",
    "\n",
    "# Show examples of each\n",
    "fig, axes = plt.subplots(N_SHOW, 3, figsize=(9, 2*N_SHOW),\n",
    "                         facecolor='white', dpi=100)\n",
    "fig.suptitle('Human-labeled examples', fontsize=24)\n",
    "axes[0, 0].set_title('Worker labels', fontsize=14)\n",
    "axes[0, 1].set_title('Consolidated label', fontsize=14)\n",
    "axes[0, 2].set_title('True label', fontsize=14)\n",
    "for row, img in enumerate(np.random.choice(human_labeled_subset, size=N_SHOW)):\n",
    "    img.download(LOCAL_IMG_DIR)\n",
    "    img.plot_worker_bbs(axes[row, 0])\n",
    "    img.plot_consolidated_bbs(axes[row, 1])\n",
    "    img.plot_gt_bbs(axes[row, 2])\n",
    "\n",
    "if auto_labeled_subset:\n",
    "    fig, axes = plt.subplots(N_SHOW, 2, figsize=(6, 2*N_SHOW),\n",
    "                             facecolor='white', dpi=100)\n",
    "    fig.suptitle('Auto-labeled examples', fontsize=24)\n",
    "    axes[0, 0].set_title('Auto-label', fontsize=14)\n",
    "    axes[0, 1].set_title('True label', fontsize=14)\n",
    "    for row, img in enumerate(np.random.choice(auto_labeled_subset, size=N_SHOW)):\n",
    "        img.download(LOCAL_IMG_DIR)\n",
    "        img.plot_consolidated_bbs(axes[row, 0])\n",
    "        img.plot_gt_bbs(axes[row, 1])\n",
    "else:\n",
    "    print(\"No images were auto-labeled.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the resulting bounding boxes to a pdf\n",
    "Finally, we plot the results to two large pdf files. You can adjust the number of `rows_per_page` and `columns_per_page` if you would like. With the default settings, the pdfs will display 25 images per page. Each page will contain images annotated either by human annotators or by the auto-labeling mechanism. The first, `ground-truth-od-confidence.pdf`, contains images sorted by the confidence Ground Truth has in its prediction. The second, `ground-truth-od-miou.pdf`, contains the same images, but sorted by the quality of the annotations compared to the standard labels from the Open Images Dataset. See  the [Compare Ground Truth results to standard labels](#Compare-Ground-Truth-results-to-standard-labels) section for more details.\n",
    "\n",
    "We will only plot 10 each of the human- and auto-annotated images. You can set `N_SHOW` to another number if you want to only plot more of the images."
   ]
  },
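  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PDFs are filled page by page and, within a page, row by row. The sketch below uses hypothetical helpers (not part of `ground_truth_od`) to map an image's position in the sorted list to its page, row, and column, and to count pages with the same ceiling division as the next cell.\n",
    "\n",
    "```python\n",
    "def page_slot(img_idx, rows_per_page=5, columns_per_page=5):\n",
    "    '''Map a position in the sorted image list to (page, row, column).'''\n",
    "    n_per_page = rows_per_page * columns_per_page\n",
    "    page, offset = divmod(img_idx, n_per_page)\n",
    "    row, column = divmod(offset, columns_per_page)  # Row-major: a row holds columns_per_page images.\n",
    "    return page, row, column\n",
    "\n",
    "\n",
    "def count_pages(n_images, rows_per_page=5, columns_per_page=5):\n",
    "    '''Number of pages needed for n_images (ceiling division).'''\n",
    "    n_per_page = rows_per_page * columns_per_page\n",
    "    # Floor division rounds -1 // n_per_page down to -1, so this also returns 0 for 0 images.\n",
    "    return (n_images - 1) // n_per_page + 1\n",
    "```\n",
    "\n",
    "With the default 5 x 5 grid, the image at index 26 lands on the second page, first row, second column: `page_slot(26)` returns `(1, 0, 1)` and `count_pages(27)` returns `2`. For non-square grids, note that the within-page index advances by `columns_per_page` per row."
   ]
  },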
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''Create pdfs with images sorted by miou and confidence.'''\n",
    "\n",
    "N_SHOW = 10\n",
    "\n",
    "# Created, sort list of imgs and mious.\n",
    "h_img_mious = [(img, img.compute_iou_bb()) for img in human_labeled]\n",
    "a_img_mious = [(img, img.compute_iou_bb()) for img in auto_labeled]\n",
    "h_img_mious.sort(key=lambda x: x[1], reverse=True)\n",
    "a_img_mious.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "# Create, sort the images by confidence.\n",
    "h_img_confs = [(img, img.compute_img_confidence()) for img in human_labeled]\n",
    "a_img_confs = [(img, img.compute_img_confidence()) for img in auto_labeled]\n",
    "h_img_confs.sort(key=lambda x: x[1], reverse=True)\n",
    "a_img_confs.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "# Define number of rows, columns per page.\n",
    "rows_per_page = 5\n",
    "columns_per_page = 5\n",
    "n_per_page = rows_per_page * columns_per_page\n",
    "\n",
    "\n",
    "def title_page(title):\n",
    "    '''Create a page with only text.'''\n",
    "    plt.figure(figsize=(10, 10), facecolor='white', dpi=100)\n",
    "    plt.text(0.1, 0.5, s=title, fontsize=20)\n",
    "    plt.axis('off')\n",
    "    pdf.savefig()\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def page_loop(mious, axes, worker=False):\n",
    "    '''Loop over a single image page of the output pdf.'''\n",
    "    for i, row in enumerate(axes):\n",
    "        for j, ax in enumerate(row):\n",
    "            img_idx = n_per_page*page + rows_per_page*i + j\n",
    "\n",
    "            # Break out of loop if all the images are plotted.\n",
    "            if img_idx >= min(N_SHOW, len(mious)):\n",
    "                return\n",
    "\n",
    "            img, miou = mious[img_idx]\n",
    "            img.download(LOCAL_IMG_DIR)\n",
    "            if worker:\n",
    "                img.plot_worker_bbs(\n",
    "                    ax, img_kwargs={'aspect': 'auto'},\n",
    "                    box_kwargs={'lw': .5})\n",
    "            else:\n",
    "                img.plot_gt_bbs(\n",
    "                    ax, img_kwargs={'aspect': 'auto'},\n",
    "                    box_kwargs={'edgecolor': 'C2', 'lw': .5})\n",
    "                img.plot_consolidated_bbs(\n",
    "                    ax, img_kwargs={'aspect': 'auto'},\n",
    "                    box_kwargs={'edgecolor': 'C1', 'lw': .5})\n",
    "\n",
    "\n",
    "# Create pdfs for the images sorted by confidence and by mIoU.\n",
    "mode_metrics = (('mIoU', (('Worker', h_img_mious),\n",
    "                          ('Consolidated human', h_img_mious),\n",
    "                          ('Auto', a_img_mious))),\n",
    "                ('confidence', (('Worker', h_img_confs),\n",
    "                                ('Consolidated human', h_img_confs),\n",
    "                                ('Auto', a_img_confs))))\n",
    "\n",
    "for mode, labels_metrics in mode_metrics:\n",
    "    pdfname = f'ground-truth-od-{mode}.pdf'\n",
    "    with PdfPages(pdfname) as pdf:\n",
    "        title_page('Images labeled by SageMaker Ground Truth\\n'\n",
    "                   f'and sorted by {mode}')\n",
    "\n",
    "        print(f'Plotting images sorted by {mode}...')\n",
    "\n",
    "        # Show human- and auto-labeled images.\n",
    "        for label, metrics in labels_metrics:\n",
    "            worker = (label == 'Worker')\n",
    "            if worker:\n",
    "                title_page('Original worker labels')\n",
    "            else:\n",
    "                title_page(\n",
    "                    f'{label} labels in orange,\\n'\n",
    "                    'Open Image annotations in green')\n",
    "            n_images = min(len(metrics), N_SHOW)\n",
    "            n_pages = (n_images-1)//n_per_page + 1\n",
    "\n",
    "            print(f'Plotting {label.lower()}-labeled images...')\n",
    "            for page in range(n_pages):\n",
    "                print(f'{page*n_per_page}/{n_images}')\n",
    "                fig, axes = plt.subplots(\n",
    "                    rows_per_page, columns_per_page, dpi=125)\n",
    "                page_loop(metrics, axes, worker=worker)\n",
    "                for ax in axes.ravel():\n",
    "                    ax.axis('off')\n",
    "\n",
    "                # Find the max/min mIoU or confidence on each page.\n",
    "                metrics_page = metrics[page*n_per_page:min((page+1)*n_per_page,\n",
    "                                                           n_images)]\n",
    "                max_metric = metrics_page[0][1]\n",
    "                min_metric = metrics_page[-1][1]\n",
    "                fig.suptitle(\n",
    "                    f'{mode} range: [{max_metric:1.3f}, {min_metric:1.3f}]')\n",
    "                pdf.savefig()\n",
    "                plt.close()\n",
    "\n",
    "print('Done.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compare Ground Truth results to standard labels\n",
    "\n",
    "**This section should take about 5 minutes to complete.**\n",
    "\n",
    "Sometimes we have an alternative set of data labels available. \n",
    "For example, the Open Images data has already been carefully annotated by a professional annotation workforce.\n",
    "This allows us to perform additional analysis that compares Ground Truth labels to the standard labels.\n",
    "When doing so, it is important to bear in mind that any image labels created by humans\n",
    "will most likely not be 100% accurate. For this reason, it is better to think of labeling accuracy as\n",
    "\"adherence to a particular standard / set of labels\" rather than \"how good (in absolute terms) are the Ground Truth labels.\"\n",
    "\n",
    "## Compute mIoUs for images in the dataset\n",
    "The following cell plots a histogram of the mean intersections-over-unions (mIoUs) between labels produced by Ground Truth and reference labels from the Open Images Dataset. The intersection over union, also known as the [Jaccard index](https://en.wikipedia.org/wiki/Jaccard_index), of two bounding boxes is a measure of their similarity. Because each image can contain multiple bounding boxes, we take the mean of the IoUs to measure the success of the labeling for that image."
   ]
  },
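  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For two boxes A and B, `IoU(A, B) = area(A ∩ B) / area(A ∪ B)`. The sketch below computes it for corner-format boxes `(x0, y0, x1, y1)` and then averages, over the reference boxes, the best IoU each one achieves against any Ground Truth box. This matching rule is an assumption chosen for illustration; it may not match exactly how `compute_iou_bb` in `ground_truth_od` defines the image-level mIoU.\n",
    "\n",
    "```python\n",
    "def iou(box_a, box_b):\n",
    "    '''Intersection over union of two boxes given as (x0, y0, x1, y1).'''\n",
    "    ix0, iy0 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])\n",
    "    ix1, iy1 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])\n",
    "    inter = max(0, ix1 - ix0) * max(0, iy1 - iy0)  # Zero if the boxes do not overlap.\n",
    "    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])\n",
    "    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])\n",
    "    union = area_a + area_b - inter\n",
    "    return inter / union if union > 0 else 0.0\n",
    "\n",
    "\n",
    "def mean_iou(pred_boxes, ref_boxes):\n",
    "    '''Average, over reference boxes, of the best IoU with any predicted box.'''\n",
    "    if not ref_boxes:\n",
    "        return 0.0  # Assumed convention for images without reference boxes.\n",
    "    best = [max((iou(p, r) for p in pred_boxes), default=0.0) for r in ref_boxes]\n",
    "    return sum(best) / len(best)\n",
    "```\n",
    "\n",
    "For example, a single Ground Truth box that exactly matches one of two reference boxes, and misses the other entirely, gives an mIoU of 0.5 under this rule."
   ]
  },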
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''Plot the distribution of mIoUs by image in the dataset.'''\n",
    "h_mious = np.array([h_img_miou[1] for h_img_miou in h_img_mious])\n",
    "a_mious = np.array([a_img_miou[1] for a_img_miou in a_img_mious])\n",
    "xvals = np.linspace(0, 1, 17)\n",
    "xticks = np.linspace(0, 1, 5)\n",
    "\n",
    "plt.figure(figsize=(12, 5), dpi=300, facecolor='white')\n",
    "plt.hist([h_mious, a_mious], rwidth=.8, edgecolor='k',\n",
    "         bins=xvals, label=['Human', 'Auto'])\n",
    "plt.xticks(xticks)\n",
    "plt.title(f'{len(h_mious)} human-labeled images with mIoU {np.mean(h_mious):.2f}\\n{len(a_mious)} auto-labeled images with mIoU {np.mean(a_mious):.2f}')\n",
    "plt.ylabel('Number of images')\n",
    "plt.xlabel('mIoU')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the results\n",
    "It is useful to see what corresponds to a good or bad mIoU in practice. The following cell displays images with the highest and lowest mIoUs vs the standard labels for both the human- and auto-labeled images. As before, the Ground Truth bounding boxes are in blue and the standard boxes are in lime green.\n",
    "\n",
    "In our example run, the images with the lowest mIoUs demonstrated that Ground Truth can sometimes outperform standard labels. In particular, many of the standard labels for this dataset contain only one large bounding box despite the presence of many small objects in the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort the images by mIoU.\n",
    "h_img_mious.sort(key=lambda x: x[1], reverse=True)\n",
    "a_img_mious.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "\n",
    "# Plot images and mIoUs for human- vs auto-labeling.\n",
    "if a_img_mious:\n",
    "    labels = ('Human', 'Auto')\n",
    "    both_img_mious = (h_img_mious, a_img_mious)\n",
    "else:\n",
    "    labels = ('Human',)\n",
    "    both_img_mious = (h_img_mious,)\n",
    "for label, all_img_mious in zip(labels, both_img_mious):\n",
    "\n",
    "    # Do the highest and lowest mious\n",
    "    tb_img_mious = (all_img_mious[:6], all_img_mious[-6:])\n",
    "    titles = ('highest', 'lowest')\n",
    "    for img_mious, title in zip(tb_img_mious, titles):\n",
    "\n",
    "        # Make a figure with six images.\n",
    "        fig, axes = plt.subplots(\n",
    "            2, 3, figsize=(12, 4), dpi=100, facecolor='white')\n",
    "        for (img, miou), ax in zip(img_mious, axes.ravel()):\n",
    "            img.download(LOCAL_IMG_DIR)\n",
    "            img.plot_consolidated_bbs(\n",
    "                ax, box_kwargs={'lw': 1.5, 'color': 'blue'})\n",
    "            img.plot_gt_bbs(ax, box_kwargs={'lw': 1, 'color': 'lime'})\n",
    "            ax.set_title(f\"mIoU: {miou:1.3f}\")\n",
    "            ax.axis('off')\n",
    "        fig.suptitle(\n",
    "            f'{label}-labeled images with the {title} mIoUs', fontsize=16)\n",
    "        fig.tight_layout(rect=[0, 0, 1, .95])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Understand the relationship between confidence and annotation quality\n",
    "\n",
    "During both human- and auto-labeling, Ground Truth produces confidence scores associated with its labels. These scores are used internally by Ground Truth in various ways. As an example, the auto-labeling mechanism will only ouput an annotation for an image when the confidence passes a dynamically-generated threshold.\n",
    "\n",
    "In practice, Ground Truth is often used to annotate entirely new datasets for which no standard labels exist, so the confidence scores are the only available signal of annotation quality. The following cells show how well the confidence acts as a proxy for the true quality (mIoU) of the annotations."
   ]
  },
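  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confidence-threshold analysis in the next cell keeps only the images whose confidence lies in the top x% and measures the mIoU of that subset. The sketch below shows the same idea with NumPy on synthetic per-image confidences and IoUs. The numbers are made up rather than Ground Truth output, and the sketch averages per-image IoUs directly instead of calling `group_miou`. The helper name `mean_iou_above_quantile` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic per-image confidences and IoUs (illustrative values only).\n",
    "confs = np.array([0.95, 0.90, 0.80, 0.60, 0.40, 0.20])\n",
    "ious = np.array([0.90, 0.85, 0.80, 0.50, 0.40, 0.10])\n",
    "\n",
    "\n",
    "def mean_iou_above_quantile(confs, ious, q):\n",
    "    '''Mean IoU of the images whose confidence is at least the q-th quantile.'''\n",
    "    threshold = np.quantile(confs, q)\n",
    "    keep = confs >= threshold\n",
    "    return ious[keep].mean()\n",
    "\n",
    "\n",
    "# q = 0.0 keeps every image; q = 0.5 keeps roughly the top half by confidence.\n",
    "for q in (0.0, 0.5):\n",
    "    print(q, mean_iou_above_quantile(confs, ious, q))\n",
    "```\n",
    "\n",
    "If confidence is a good proxy for quality, the mean IoU should rise as the quantile threshold increases, which is what the right-hand plot in the next cell checks on the real labels."
   ]
  },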
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''Plot the mIoUs vs the confidences.'''\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from ground_truth_od import group_miou\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(\n",
    "    1, 2, dpi=100, facecolor='white', figsize=(12, 5))\n",
    "\n",
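    "# Pair each image with its confidence score.\n",
    "h_img_confs = [(img, img.compute_img_confidence()) for img in human_labeled]\n",
    "a_img_confs = [(img, img.compute_img_confidence()) for img in auto_labeled]\n",
    "\n",
    "if RUN_FULL_AL_DEMO:\n",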
    "    label_confs_colors = (('Human', h_img_confs, 'C0'),\n",
    "                          ('Auto', a_img_confs, 'C1'))\n",
    "else:\n",
    "    label_confs_colors = (('Human', h_img_confs, 'C0'),)\n",
    "\n",
    "\n",
    "ax1.set_title('mIoU vs confidence with regression lines')\n",
    "ax1.set_xlabel('Confidence')\n",
    "ax1.set_ylabel('mIoU')\n",
    "for label, img_confs, color in label_confs_colors:\n",
    "    confs = [img_conf[1] for img_conf in img_confs]\n",
    "    mious = [img_conf[0].compute_iou_bb() for img_conf in img_confs]\n",
    "\n",
    "    # Compute regression line.\n",
    "    slope, intercept, *_ = stats.linregress(confs, mious)\n",
    "    xs = np.array((0, 1))\n",
    "\n",
    "    # Plot points and line.\n",
    "    ax1.plot(confs, mious, '.', label=label, color=color)\n",
    "    ax1.plot(xs, slope * xs + intercept, color=color, lw=3)\n",
    "\n",
    "ax1.set_xlim([-0.05, 1.05])\n",
    "ax1.set_ylim([-0.05, 1.05])\n",
    "ax1.legend()\n",
    "\n",
    "\n",
    "# Compute the mIoU of subsets of the images based on confidence level.\n",
    "if RUN_FULL_AL_DEMO:\n",
    "    labels_imgs = (('Human', human_labeled),\n",
    "                   ('Auto', auto_labeled))\n",
    "else:\n",
    "    labels_imgs = (('Human', human_labeled),)\n",
    "\n",
    "deciles = np.linspace(0, .9, 10)\n",
    "\n",
    "mious_deciles = {}\n",
    "for label, imgs in labels_imgs:\n",
    "    # Find thresholds of confidences for deciles.\n",
    "    confs = np.array([img.compute_img_confidence() for img in imgs])\n",
    "    thresholds = pd.Series(confs).quantile(deciles)\n",
    "\n",
    "    # Select images with confidence at or above the thresholds.\n",
    "    mious = []\n",
    "    for decile in deciles:\n",
    "        img_subset = [img for img in imgs\n",
    "                      if img.compute_img_confidence() >= thresholds[decile]]\n",
    "\n",
    "        # Compute mious.\n",
    "        mious.append(group_miou(img_subset))\n",
    "\n",
    "    # Save the results.\n",
    "    mious_deciles[label] = mious\n",
    "\n",
    "    # Create plots.\n",
    "    ax2.plot(100-deciles*100, mious, label=label)\n",
    "    ax2.set_ylabel('mIoU')\n",
    "    ax2.set_title('Effect of increasing confidence thresholds')\n",
    "    ax2.set_xlabel('Top x% of images by confidence')\n",
    "    ax2.set_xlim([105, 5])\n",
    "    ax2.set_xticks(np.linspace(100, 10, 10))\n",
    "    ax2.legend()\n",
    "ax2.grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, concrete examples help. The next cell displays the five human-labeled and five auto-labeled images with the highest confidence scores across the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''Plot the images with the highest confidences.'''\n",
    "\n",
    "# Sort the images by confidence.\n",
    "h_img_confs = [(img, img.compute_img_confidence()) for img in human_labeled]\n",
    "a_img_confs = [(img, img.compute_img_confidence()) for img in auto_labeled]\n",
    "h_img_confs.sort(key=lambda x: x[1], reverse=True)\n",
    "a_img_confs.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "# Do both the human- and auto-labeled images.\n",
    "label_confs = (('human', h_img_confs),\n",
    "               ('auto', a_img_confs))\n",
    "for label, img_confs in label_confs:\n",
    "    plt.figure(facecolor='white', figsize=(15, 4), dpi=100)\n",
    "    plt.suptitle(\n",
    "        f'Top-5 confidence {label}-labels (orange) and corresponding '\n",
    "        'Open Images annotations (green)')\n",
    "    for img_id, (img, conf) in enumerate(img_confs[:5]):\n",
    "        img.download(LOCAL_IMG_DIR)\n",
    "        ax = plt.subplot(1, 5, img_id + 1)\n",
    "        img.plot_gt_bbs(ax, box_kwargs={'edgecolor': 'C2', 'lw': 3})\n",
    "        img.plot_consolidated_bbs(ax, box_kwargs={'edgecolor': 'C1', 'lw': 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('mIoU for the whole dataset: ', group_miou(output_images))\n",
    "print('mIoU for human-labeled images: ', group_miou(human_labeled))\n",
    "print('mIoU for auto-labeled images: ', group_miou(auto_labeled))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### See how the number of objects in the image affects confidence\n",
    "\n",
    "The next cell produces two subplots:\n",
    "* The left subplot shows, on a log scale, how many images contain each number of objects. Notice that images with many boxes are more often labeled by human workers than by auto-labeling.\n",
    "\n",
    "* The right subplot shows how the confidence associated with an image decreases as the number of objects in the image increases."
   ]
  },
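  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell computes the mean confidence for every box count from 0 to the maximum. This yields NaN for box counts that no image has, and those NaNs are masked out before plotting. As an alternative, the sketch below groups the images so that it only creates entries for box counts that actually occur. It uses synthetic `(number_of_boxes, confidence)` pairs in place of the notebook's image objects, and `mean_conf_by_box_count` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Synthetic (number_of_boxes, image_confidence) pairs for illustration only.\n",
    "images = [(1, 0.9), (1, 0.8), (2, 0.7), (4, 0.3), (4, 0.5)]\n",
    "\n",
    "\n",
    "def mean_conf_by_box_count(images):\n",
    "    '''Map each observed box count to the mean confidence of its images.'''\n",
    "    groups = defaultdict(list)\n",
    "    for n_boxes, conf in images:\n",
    "        groups[n_boxes].append(conf)\n",
    "    # Only box counts that occur appear in the result, so there are no NaNs.\n",
    "    return {n: sum(c) / len(c) for n, c in sorted(groups.items())}\n",
    "\n",
    "\n",
    "print(mean_conf_by_box_count(images))\n",
    "```"
   ]
  },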
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the number of boxes per image and create a histogram.\n",
    "nboxes_human = np.array([img.n_consolidated_boxes()\n",
    "                         for img in human_labeled])\n",
    "nboxes_auto = np.array([img.n_consolidated_boxes()\n",
    "                        for img in auto_labeled])\n",
    "\n",
    "max_boxes = max(nboxes_auto.max() if nboxes_auto.size != 0 else 0,\n",
    "                nboxes_human.max() if nboxes_human.size != 0 else 0)\n",
    "n_boxes = np.arange(0, max_boxes+2)\n",
    "\n",
    "# Find mean confidences by number of boxes.\n",
    "h_confs_by_n = []\n",
    "a_confs_by_n = []\n",
    "# Do human and auto.\n",
    "for labeled, mean_confs in ((human_labeled, h_confs_by_n),\n",
    "                            (auto_labeled, a_confs_by_n)):\n",
    "    for n_box in n_boxes:\n",
    "        imgs_with_n = [img for img in labeled\n",
    "                       if img.n_consolidated_boxes() == n_box]\n",
    "        # np.mean of an empty list is NaN; those entries are masked out below.\n",
    "        mean_conf = np.mean([img.compute_img_confidence()\n",
    "                             for img in imgs_with_n])\n",
    "        mean_confs.append(mean_conf)\n",
    "\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(\n",
    "    1, 2, figsize=(14, 4), facecolor='white', dpi=100)\n",
    "ax1.hist([nboxes_human, nboxes_auto], n_boxes,\n",
    "         label=['Human', 'Auto'], align='left')\n",
    "ax1.set_xlabel('Bounding boxes in image')\n",
    "ax1.set_title('Image counts vs number of bounding boxes')\n",
    "ax1.set_yscale('log')\n",
    "ax1.set_ylabel('Number of images')\n",
    "ax1.legend();\n",
    "\n",
    "# Keep only box counts that occur in at least one image (the others are NaN).\n",
    "h_not_nan = np.logical_not(np.isnan(h_confs_by_n))\n",
    "a_not_nan = np.logical_not(np.isnan(a_confs_by_n))\n",
    "\n",
    "# Plot.\n",
    "ax2.set_title('Image confidences vs number of bounding boxes')\n",
    "ax2.plot(n_boxes[h_not_nan], np.array(h_confs_by_n)[h_not_nan], 'D',\n",
    "         color='C0', label='Human')\n",
    "ax2.plot(n_boxes[a_not_nan], np.array(a_confs_by_n)[a_not_nan], 'D',\n",
    "         color='C1', label='Auto')\n",
    "ax2.set_xlabel('Bounding boxes in image')\n",
    "ax2.set_ylabel('Mean image confidence')\n",
    "ax2.legend();\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train an object detection model using Ground Truth labels\n",
    "At this stage, we have fully labeled our dataset and can train a machine learning model to perform object detection. We'll do so using the **augmented manifest** output of our labeling job; no additional file translation or manipulation is required! For a more complete description of the augmented manifest, see our other [example notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/ground_truth_labeling_jobs/object_detection_augmented_manifest_training/object_detection_augmented_manifest_training.ipynb).\n",
    "\n",
    "**NOTE:** Object detection is a complex task, and training neural networks to high accuracy requires large datasets and careful hyperparameter tuning. The following cells illustrate how to train a neural network using a Ground Truth output augmented manifest and how to interpret the results. However, we shouldn't expect a network trained on only 100 or 1,000 images to perform well on unseen images!\n",
    "\n",
    "First, we'll split our augmented manifest into a training set and a validation set using an 80/20 split and save the results to files that the model will use during training."
   ]
  },
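  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.random.shuffle` in the next cell is unseeded, so each run produces a different train/validation split. If you want a reproducible split, a minimal sketch with a seeded `random.Random` looks like this. The seed value and the function name `split_manifest` are arbitrary choices for illustration.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import random\n",
    "\n",
    "\n",
    "def split_manifest(lines, train_fraction=0.8, seed=0):\n",
    "    '''Shuffle manifest records with a fixed seed and split them in two.'''\n",
    "    records = [json.loads(line) for line in lines if line.strip()]\n",
    "    random.Random(seed).shuffle(records)\n",
    "    split_index = round(len(records) * train_fraction)\n",
    "    return records[:split_index], records[split_index:]\n",
    "\n",
    "\n",
    "# Ten toy manifest lines standing in for output.manifest.\n",
    "toy_lines = [json.dumps({'source-ref': f's3://bucket/img{i}.jpg'}) for i in range(10)]\n",
    "train, validation = split_manifest(toy_lines)\n",
    "print(len(train), len(validation))  # 8 2\n",
    "```\n",
    "\n",
    "Using a dedicated `random.Random` instance keeps the split independent of any other code that draws from the global random state."
   ]
  },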
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('output.manifest', 'r') as f:\n",
    "    output = [json.loads(line) for line in f.readlines()]\n",
    "\n",
    "# Shuffle output in place.\n",
    "np.random.shuffle(output)\n",
    "    \n",
    "dataset_size = len(output)\n",
    "train_test_split_index = round(dataset_size*0.8)\n",
    "\n",
    "train_data = output[:train_test_split_index]\n",
    "validation_data = output[train_test_split_index:]\n",
    "\n",
    "num_training_samples = 0\n",
    "with open('train.manifest', 'w') as f:\n",
    "    for line in train_data:\n",
    "        f.write(json.dumps(line))\n",
    "        f.write('\\n')\n",
    "        num_training_samples += 1\n",
    "    \n",
    "with open('validation.manifest', 'w') as f:\n",
    "    for line in validation_data:\n",
    "        f.write(json.dumps(line))\n",
    "        f.write('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll upload these manifest files to the previously defined S3 bucket so that they can be used in the training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp train.manifest s3://{BUCKET}/{EXP_NAME}/train.manifest\n",
    "!aws s3 cp validation.manifest s3://{BUCKET}/{EXP_NAME}/validation.manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Here we define the S3 paths for the input and output data, look up the training image that contains the built-in object detection algorithm, and instantiate a SageMaker session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "s3 = boto3.resource('s3')\n",
    "\n",
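    "# NB. get_image_uri is part of the SageMaker Python SDK v1; SDK v2 replaces it with sagemaker.image_uris.retrieve.\n",
    "training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'object-detection', repo_version='latest')\n",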
    "augmented_manifest_filename_train = 'train.manifest'\n",
    "augmented_manifest_filename_validation = 'validation.manifest'\n",
    "bucket_name = BUCKET\n",
    "s3_prefix = EXP_NAME\n",
    "s3_output_path = 's3://{}/groundtruth-od-augmented-manifest-output'.format(bucket_name) # Replace with your desired output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define paths for use in the training job request.\n",
    "s3_train_data_path = 's3://{}/{}/{}'.format(bucket_name, s3_prefix, augmented_manifest_filename_train)\n",
    "s3_validation_data_path = 's3://{}/{}/{}'.format(bucket_name, s3_prefix, augmented_manifest_filename_validation)\n",
    "\n",
    "print(\"Augmented manifest for training data: {}\".format(s3_train_data_path))\n",
    "print(\"Augmented manifest for validation data: {}\".format(s3_validation_data_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmented_manifest_s3_key = s3_train_data_path.split(bucket_name)[1][1:]\n",
    "s3_obj = s3.Object(bucket_name, augmented_manifest_s3_key)\n",
    "augmented_manifest = s3_obj.get()['Body'].read().decode('utf-8')\n",
    "# Drop empty lines (such as the one after the trailing newline) so they are not counted as samples.\n",
    "augmented_manifest_lines = [line for line in augmented_manifest.splitlines() if line.strip()]\n",
    "num_training_samples = len(augmented_manifest_lines)  # Number of training samples for the training job request.\n",
    "\n",
    "# Determine the keys in the training manifest and exclude the metadata from the labeling job.\n",
    "attribute_names = list(json.loads(augmented_manifest_lines[0]).keys())\n",
    "attribute_names = [attrib for attrib in attribute_names if 'meta' not in attrib]"
   ]
  },
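  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each line of the augmented manifest is a JSON object. The cell above keeps every key that does not contain `meta`, which for an object detection job leaves `source-ref` followed by the label attribute. The sketch below applies the same filter to a hand-written example line. The label attribute name `bird-labels` and all field values are hypothetical; in your manifest, the label key is whatever `LabelAttributeName` you set for the labeling job.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Hypothetical augmented-manifest line for an object detection job.\n",
    "example_line = json.dumps({\n",
    "    'source-ref': 's3://my-bucket/images/0001.jpg',\n",
    "    'bird-labels': {\n",
    "        'image_size': [{'width': 500, 'height': 400, 'depth': 3}],\n",
    "        'annotations': [{'class_id': 0, 'left': 10, 'top': 20,\n",
    "                         'width': 100, 'height': 80}],\n",
    "    },\n",
    "    'bird-labels-metadata': {'class-map': {'0': 'bird'}},\n",
    "})\n",
    "\n",
    "\n",
    "def attribute_names_from_line(line):\n",
    "    '''Keep the manifest keys that are not labeling-job metadata, in order.'''\n",
    "    return [key for key in json.loads(line) if 'meta' not in key]\n",
    "\n",
    "\n",
    "print(attribute_names_from_line(example_line))  # ['source-ref', 'bird-labels']\n",
    "```"
   ]
  },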
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    if attribute_names == [\"source-ref\", \"XXXX\"]:\n",
    "        raise Exception(\"The 'attribute_names' variable is set to default values. Please check your augmented manifest file for the label attribute name and set the 'attribute_names' variable accordingly.\")\n",
    "except NameError:\n",
    "    raise Exception(\"The attribute_names variable is not defined. Please check your augmented manifest file for the label attribute name and set the 'attribute_names' variable accordingly.\")\n",
    "\n",
    "# Create unique job name\n",
    "job_name_prefix = 'ground-truth-od-demo'\n",
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "model_job_name = job_name_prefix + timestamp\n",
    "\n",
    "training_params = \\\n",
    "    {\n",
    "        \"AlgorithmSpecification\": {\n",
    "            # NB. training_image is defined in the Setup cell above.\n",
    "            \"TrainingImage\": training_image,\n",
    "            \"TrainingInputMode\": \"Pipe\"\n",
    "        },\n",
    "        \"RoleArn\": role,\n",
    "        \"OutputDataConfig\": {\n",
    "            \"S3OutputPath\": s3_output_path\n",
    "        },\n",
    "        \"ResourceConfig\": {\n",
    "            \"InstanceCount\": 1,\n",
    "            \"InstanceType\": \"ml.p3.2xlarge\",\n",
    "            \"VolumeSizeInGB\": 50\n",
    "        },\n",
    "        \"TrainingJobName\": model_job_name,\n",
    "        \"HyperParameters\": {  # NB. These hyperparameters are at the user's discretion and are beyond the scope of this demo.\n",
    "            \"base_network\": \"resnet-50\",\n",
    "            \"use_pretrained_model\": \"1\",\n",
    "            \"num_classes\": \"1\",\n",
    "            \"mini_batch_size\": \"1\",\n",
    "            \"epochs\": \"30\",\n",
    "            \"learning_rate\": \"0.001\",\n",
    "            \"lr_scheduler_step\": \"\",\n",
    "            \"lr_scheduler_factor\": \"0.1\",\n",
    "            \"optimizer\": \"sgd\",\n",
    "            \"momentum\": \"0.9\",\n",
    "            \"weight_decay\": \"0.0005\",\n",
    "            \"overlap_threshold\": \"0.5\",\n",
    "            \"nms_threshold\": \"0.45\",\n",
    "            \"image_shape\": \"300\",\n",
    "            \"label_width\": \"350\",\n",
    "            \"num_training_samples\": str(num_training_samples)\n",
    "        },\n",
    "        \"StoppingCondition\": {\n",
    "            \"MaxRuntimeInSeconds\": 86400\n",
    "        },\n",
    "        \"InputDataConfig\": [\n",
    "            {\n",
    "                \"ChannelName\": \"train\",\n",
    "                \"DataSource\": {\n",
    "                    \"S3DataSource\": {\n",
    "                        \"S3DataType\": \"AugmentedManifestFile\",  # NB. Augmented Manifest\n",
    "                        \"S3Uri\": s3_train_data_path,\n",
    "                        \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                        # NB. This must correspond to the JSON field names in your augmented manifest.\n",
    "                        \"AttributeNames\": attribute_names\n",
    "                    }\n",
    "                },\n",
    "                \"ContentType\": \"application/x-recordio\",\n",
    "                \"RecordWrapperType\": \"RecordIO\",\n",
    "                \"CompressionType\": \"None\"\n",
    "            },\n",
    "            {\n",
    "                \"ChannelName\": \"validation\",\n",
    "                \"DataSource\": {\n",
    "                    \"S3DataSource\": {\n",
    "                        \"S3DataType\": \"AugmentedManifestFile\",  # NB. Augmented Manifest\n",
    "                        \"S3Uri\": s3_validation_data_path,\n",
    "                        \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                        # NB. This must correspond to the JSON field names in your augmented manifest.\n",
    "                        \"AttributeNames\": attribute_names\n",
    "                    }\n",
    "                },\n",
    "                \"ContentType\": \"application/x-recordio\",\n",
    "                \"RecordWrapperType\": \"RecordIO\",\n",
    "                \"CompressionType\": \"None\"\n",
    "            }\n",
    "        ]\n",
    "    }\n",
    "\n",
    "print('Training job name: {}'.format(model_job_name))\n",
    "print('\\nInput Data Location: {}'.format(\n",
    "    training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we create the SageMaker training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = boto3.client(service_name='sagemaker')\n",
    "client.create_training_job(**training_params)\n",
    "\n",
    "# Confirm that the training job has started\n",
    "status = client.describe_training_job(TrainingJobName=model_job_name)['TrainingJobStatus']\n",
    "print('Training job current status: {}'.format(status))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the progress of the training job, you can repeatedly evaluate the following cell. When the training job status reads `'Completed'`, move on to the next part of the tutorial."
   ]
  },
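  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of re-running the status cell by hand, you can poll in a loop. The following sketch is a generic polling helper. `wait_for_status` is a hypothetical name, and it takes any zero-argument function that returns a status string, so you could pass `lambda: client.describe_training_job(TrainingJobName=model_job_name)['TrainingJobStatus']`. boto3 also ships a `training_job_completed_or_stopped` waiter for the SageMaker client that serves the same purpose.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def wait_for_status(get_status, terminal=('Completed', 'Failed', 'Stopped'),\n",
    "                    poll_seconds=60, max_polls=1440):\n",
    "    '''Call get_status() until it returns a terminal status, then return it.'''\n",
    "    for _ in range(max_polls):\n",
    "        status = get_status()\n",
    "        if status in terminal:\n",
    "            return status\n",
    "        time.sleep(poll_seconds)\n",
    "    raise TimeoutError('Status did not reach a terminal state in time.')\n",
    "\n",
    "\n",
    "# Usage with a fake status sequence instead of a real training job.\n",
    "fake_statuses = iter(['InProgress', 'InProgress', 'Completed'])\n",
    "print(wait_for_status(lambda: next(fake_statuses), poll_seconds=0))  # Completed\n",
    "```"
   ]
  },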
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = boto3.client(service_name='sagemaker')\n",
    "description = client.describe_training_job(TrainingJobName=model_job_name)\n",
    "print('Training job status: ', description['TrainingJobStatus'])\n",
    "print('Secondary status: ', description['SecondaryStatus'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
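    "# Block until training finishes, then fetch the job description (including the model artifacts).\n",
    "client.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=model_job_name)\n",
    "training_info = client.describe_training_job(TrainingJobName=model_job_name)"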
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy the Model \n",
    "\n",
    "Now that we've fully labeled our dataset and have a trained model, we want to use the model to perform inference.\n",
    "\n",
    "The SageMaker object detection algorithm currently accepts only JPEG- and PNG-encoded images as inference input. The output is in JSON format, or in JSON Lines format for batch transform.\n",
    "\n",
    "This section involves several steps:\n",
    "1. Create Model: Create a SageMaker model from the training output.\n",
    "2. Batch Transform: Create a transform job to perform batch inference.\n",
    "3. Host the model for realtime inference: Create an inference endpoint and perform realtime inference."
   ]
  },
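  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the model accepts only JPEG and PNG input, you may want to check a file before sending it for inference. This is a minimal sketch that relies on the standard file signatures: JPEG data starts with the bytes `FF D8 FF`, and PNG data with `89 50 4E 47 0D 0A 1A 0A`. `detect_image_format` is a hypothetical helper, not part of SageMaker.\n",
    "\n",
    "```python\n",
    "JPEG_SIGNATURE = bytes([0xFF, 0xD8, 0xFF])\n",
    "PNG_SIGNATURE = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A])\n",
    "\n",
    "\n",
    "def detect_image_format(payload):\n",
    "    '''Return jpeg or png based on the leading bytes, else None.'''\n",
    "    if payload.startswith(JPEG_SIGNATURE):\n",
    "        return 'jpeg'\n",
    "    if payload.startswith(PNG_SIGNATURE):\n",
    "        return 'png'\n",
    "    return None\n",
    "\n",
    "\n",
    "print(detect_image_format(PNG_SIGNATURE + bytes(8)))  # png\n",
    "print(detect_image_format(b'GIF89a'))  # None\n",
    "```\n",
    "\n",
    "The check works on both `bytes` and `bytearray`, so it can be applied to the `payload` read from `test_image` before calling the endpoint."
   ]
  },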
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client('sagemaker')\n",
    "\n",
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "model_name = 'groundtruth-demo-od-model' + timestamp\n",
    "print(model_name)\n",
    "model_data = training_info['ModelArtifacts']['S3ModelArtifacts']\n",
    "print(model_data)\n",
    "\n",
    "primary_container = {\n",
    "    'Image': training_image,\n",
    "    'ModelDataUrl': model_data,\n",
    "}\n",
    "\n",
    "create_model_response = sagemaker_client.create_model(\n",
    "    ModelName = model_name,\n",
    "    ExecutionRoleArn = role,\n",
    "    PrimaryContainer = primary_container)\n",
    "\n",
    "print(create_model_response['ModelArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Transform\n",
    "We now create a SageMaker Batch Transform job using the model created above to perform batch prediction.\n",
    "\n",
    "### Download Test Data\n",
    "First, let's download a test image that has been held out from the training and validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find a bird not in the images labeled by Ground Truth.\n",
    "img_ids = {img.filename.split('.')[0] for img in output_images}\n",
    "with open('test-annotations-bbox.csv', 'r') as f:\n",
    "    for line in f.readlines()[1:]:\n",
    "        line = line.strip().split(',')\n",
    "        img_id, _, cls_id, conf, xmin, xmax, ymin, ymax, *_ = line\n",
    "        if img_id in skip_these_images:\n",
    "            continue\n",
    "        if cls_id in good_subclasses:\n",
    "            # Use the first such image that Ground Truth did not label.\n",
    "            if img_id not in img_ids:\n",
    "                test_bird = img_id\n",
    "                break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "test_image = test_bird + '.jpg'\n",
    "os.system(f'wget https://s3.amazonaws.com/open-images-dataset/test/{test_image}')\n",
    "Image(test_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_input = 's3://{}/{}/test/'.format(BUCKET, EXP_NAME)\n",
    "print(test_image)\n",
    "\n",
    "!aws s3 cp $test_image $batch_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "batch_job_name = \"object-detection-model\" + timestamp\n",
    "request = \\\n",
    "{\n",
    "    \"TransformJobName\": batch_job_name,\n",
    "    \"ModelName\": model_name,\n",
    "    \"MaxConcurrentTransforms\": 16,\n",
    "    \"MaxPayloadInMB\": 6,\n",
    "    \"BatchStrategy\": \"SingleRecord\",\n",
    "    \"TransformOutput\": {\n",
    "        \"S3OutputPath\": 's3://{}/{}/{}/output/'.format(BUCKET, EXP_NAME, batch_job_name)\n",
    "    },\n",
    "    \"TransformInput\": {\n",
    "        \"DataSource\": {\n",
    "            \"S3DataSource\": {\n",
    "                \"S3DataType\": \"S3Prefix\",\n",
    "                \"S3Uri\": batch_input\n",
    "            }\n",
    "        },\n",
    "        \"ContentType\": \"application/x-image\",\n",
    "        \"SplitType\": \"None\",\n",
    "        \"CompressionType\": \"None\"\n",
    "    },\n",
    "    \"TransformResources\": {\n",
    "        \"InstanceType\": \"ml.p2.xlarge\",\n",
    "        \"InstanceCount\": 1\n",
    "    }\n",
    "}\n",
    "\n",
    "print('Transform job name: {}'.format(batch_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client('sagemaker')\n",
    "sagemaker_client.create_transform_job(**request)\n",
    "\n",
    "print(\"Created Transform job with name: \", batch_job_name)\n",
    "\n",
    "while True:\n",
    "    response = sagemaker_client.describe_transform_job(TransformJobName=batch_job_name)\n",
    "    status = response['TransformJobStatus']\n",
    "    if status == 'Completed':\n",
    "        print('Transform job ended with status: ' + status)\n",
    "        break\n",
    "    if status in ('Failed', 'Stopped'):\n",
    "        message = response.get('FailureReason', 'no failure reason reported')\n",
    "        print('Transform job ended with status {}: {}'.format(status, message))\n",
    "        raise Exception('Transform job did not complete')\n",
    "    time.sleep(30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect the results\n",
    "\n",
    "The following cell plots the predicted bounding boxes for our example image. Each prediction is a list `[class_id, confidence, xmin, ymin, xmax, ymax]`, with the box coordinates normalized by the image width and height. Inside the function `get_predictions`, we keep only the bounding boxes with a confidence score above a threshold (in this case, 0.2). This is because the object detection model we trained always outputs a fixed number of box candidates, and we need a cutoff to eliminate the spurious ones."
   ]
  },
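  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a standalone sketch, the following code applies the same 0.2 threshold and the normalized-to-pixel conversion from `make_predicted_image` to a made-up model response, without any S3 or `BoxedImage` dependencies. It assumes the prediction layout unpacked in `make_predicted_image`; the response values, the image size, and the name `to_pixel_boxes` are illustrative.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Made-up model response in the format parsed by get_predictions.\n",
    "response_body = json.dumps({'prediction': [\n",
    "    [0.0, 0.91, 0.10, 0.20, 0.50, 0.60],\n",
    "    [0.0, 0.05, 0.00, 0.00, 1.00, 1.00],\n",
    "]})\n",
    "\n",
    "\n",
    "def to_pixel_boxes(body, img_width, img_height, threshold=0.2):\n",
    "    '''Filter predictions by confidence and convert them to pixel boxes.'''\n",
    "    boxes = []\n",
    "    for class_id, conf, xmin, ymin, xmax, ymax in json.loads(body)['prediction']:\n",
    "        if conf <= threshold:\n",
    "            continue  # Drop low-confidence candidates.\n",
    "        boxes.append({'class_id': int(class_id),\n",
    "                      'left': xmin * img_width,\n",
    "                      'top': ymin * img_height,\n",
    "                      'width': (xmax - xmin) * img_width,\n",
    "                      'height': (ymax - ymin) * img_height})\n",
    "    return boxes\n",
    "\n",
    "\n",
    "print(to_pixel_boxes(response_body, img_width=400, img_height=300))\n",
    "```"
   ]
  },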
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.parse import urlparse\n",
    "\n",
    "s3_client = boto3.client('s3')\n",
    "\n",
    "batch_output = 's3://{}/{}/{}/output/'.format(BUCKET, EXP_NAME, batch_job_name)\n",
    "\n",
    "\n",
    "def list_objects(s3_client, bucket, prefix):\n",
    "    response = s3_client.list_objects(Bucket=bucket, Prefix=prefix)\n",
    "    objects = [content['Key'] for content in response['Contents']]\n",
    "    return objects\n",
    "\n",
    "\n",
    "def get_predictions(s3_client, bucket, key):\n",
    "    '''Download a batch transform output file and return its confident predictions.'''\n",
    "    filename = key.split('/')[-1]\n",
    "    s3_client.download_file(bucket, key, filename)\n",
    "    with open(filename) as f:\n",
    "        data = json.load(f)\n",
    "    predictions = data['prediction']\n",
    "\n",
    "    # Return only the predictions with confidence above the threshold of 0.2.\n",
    "    return [prediction for prediction in predictions if prediction[1] > 0.2]\n",
    "\n",
    "\n",
    "def make_predicted_image(predictions, img_id, uri):\n",
    "    '''Make a BoxedImage object from the output of batch/realtime prediction.\n",
    "\n",
    "    Args:\n",
    "      predictions: list, output of get_predictions.\n",
    "      img_id: str, identifier for the new BoxedImage.\n",
    "      uri: str, s3 uri of input image.\n",
    "\n",
    "    Returns:\n",
    "      BoxedImage object with predicted bounding boxes.\n",
    "    '''\n",
    "    img = BoxedImage(id=img_id, uri=uri)\n",
    "    img.download('.')\n",
    "    imread_img = img.imread()\n",
    "    imh, imw, *_ = imread_img.shape\n",
    "\n",
    "    # Create boxes, converting normalized coordinates to pixels.\n",
    "    for prediction in predictions:\n",
    "        class_id, confidence, xmin, ymin, xmax, ymax = prediction\n",
    "        boxdata = {'class_id': class_id,\n",
    "                   'height': (ymax-ymin)*imh,\n",
    "                   'width': (xmax-xmin)*imw,\n",
    "                   'left': xmin*imw,\n",
    "                   'top': ymin*imh}\n",
    "        box = BoundingBox(boxdata=boxdata, image_id=img.id)\n",
    "        img.consolidated_boxes.append(box)\n",
    "\n",
    "    return img\n",
    "\n",
    "\n",
    "inputs = list_objects(s3_client, BUCKET, urlparse(\n",
    "    batch_input).path.lstrip('/'))\n",
    "print(\"Input: \" + str(inputs[:2]))\n",
    "\n",
    "outputs = list_objects(s3_client, BUCKET, urlparse(\n",
    "    batch_output).path.lstrip('/'))\n",
    "print(\"Output: \" + str(outputs[:2]))\n",
    "\n",
    "# Download prediction results.\n",
    "batch_boxes_data = get_predictions(s3_client, BUCKET, outputs[0])\n",
    "batch_uri = f's3://{BUCKET}/{inputs[0]}'\n",
    "batch_img = make_predicted_image(batch_boxes_data, 'BatchTest', batch_uri)\n",
    "\n",
    "# Plot the image and predicted boxes.\n",
    "fig, ax = plt.subplots()\n",
    "batch_img.plot_consolidated_bbs(ax)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Realtime Inference\n",
    "\n",
    "We now host the model with an endpoint and perform realtime inference.\n",
    "\n",
    "This section involves several steps:\n",
    "\n",
    "1. Create endpoint configuration - Create a configuration defining an endpoint.\n",
    "2. Create endpoint - Use the configuration to create an inference endpoint.\n",
    "3. Perform inference - Perform inference on some input data using the endpoint.\n",
    "4. Clean up - Delete the endpoint and model\n",
    "\n",
    "### Create Endpoint Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "endpoint_config_name = job_name_prefix + '-epc' + timestamp\n",
    "endpoint_config_response = sagemaker_client.create_endpoint_config(\n",
    "    EndpointConfigName = endpoint_config_name,\n",
    "    ProductionVariants=[{\n",
    "        'InstanceType':'ml.m4.xlarge',\n",
    "        'InitialInstanceCount':1,\n",
    "        'ModelName':model_name,\n",
    "        'VariantName':'AllTraffic'}])\n",
    "\n",
    "print('Endpoint configuration name: {}'.format(endpoint_config_name))\n",
    "print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Endpoint\n",
    "\n",
    "The next cell creates an endpoint that can be validated and incorporated into production applications. This takes about 10 minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "endpoint_name = job_name_prefix + '-ep' + timestamp\n",
    "print('Endpoint name: {}'.format(endpoint_name))\n",
    "\n",
    "endpoint_params = {\n",
    "    'EndpointName': endpoint_name,\n",
    "    'EndpointConfigName': endpoint_config_name,\n",
    "}\n",
    "endpoint_response = sagemaker_client.create_endpoint(**endpoint_params)\n",
    "print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))\n",
    "\n",
    "# get the status of the endpoint\n",
    "response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = response['EndpointStatus']\n",
    "print('EndpointStatus = {}'.format(status))\n",
    "\n",
    "# wait until the status has changed\n",
    "sagemaker_client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)\n",
    "\n",
    "# print the status of the endpoint\n",
    "endpoint_response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = endpoint_response['EndpointStatus']\n",
    "print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n",
    "\n",
    "if status != 'InService':\n",
    "    raise Exception('Endpoint creation failed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform inference\n",
    "\n",
    "The following cell reads the test image as raw bytes, sends it to the endpoint, filters the returned predictions with the same 0.2 confidence threshold, and plots the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the held-out test image as raw bytes.\n",
    "with open(test_image, 'rb') as f:\n",
    "    payload = bytearray(f.read())\n",
    "\n",
    "# Send the image to the endpoint for real-time prediction.\n",
    "client = boto3.client('sagemaker-runtime')\n",
    "response = client.invoke_endpoint(EndpointName=endpoint_name,\n",
    "                                  ContentType='application/x-image',\n",
    "                                  Body=payload)\n",
    "\n",
    "# Each prediction is [class_index, confidence, xmin, ymin, xmax, ymax];\n",
    "# keep only detections with a confidence score above 0.2.\n",
    "result = json.loads(response['Body'].read())\n",
    "predictions = [prediction for prediction in result['prediction'] if prediction[1] > .2]\n",
    "\n",
    "# The batch job labeled the same held-out image, so reuse its URI.\n",
    "realtime_uri = batch_uri\n",
    "realtime_img = make_predicted_image(predictions, 'RealtimeTest', realtime_uri)\n",
    "\n",
    "# Plot the real-time prediction.\n",
    "fig, ax = plt.subplots()\n",
    "realtime_img.download('.')\n",
    "realtime_img.plot_consolidated_bbs(ax)"
   ]
  },
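  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each entry in `result['prediction']` has the form `[class_index, confidence, xmin, ymin, xmax, ymax]`, where the corner coordinates are fractions of the image width and height. The next cell is a minimal sketch, separate from the `make_predicted_image` helper used above, that applies the same 0.2 confidence cut-off and converts the normalized corners to pixel coordinates. The helper name `decode_detections`, the example response, and the image size are illustrative assumptions, not real model output; in this notebook you could pass it `result` together with the width and height of `test_image`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_detections(result, img_width, img_height, threshold=0.2):\n",
    "    # Convert the object detection response into (class_index, confidence, pixel_box) tuples.\n",
    "    # Assumes each prediction is [class_index, confidence, xmin, ymin, xmax, ymax]\n",
    "    # with coordinates normalized to [0, 1].\n",
    "    boxes = []\n",
    "    for class_index, confidence, xmin, ymin, xmax, ymax in result['prediction']:\n",
    "        if confidence <= threshold:\n",
    "            continue  # same rule as above: keep only scores strictly above the threshold\n",
    "        pixel_box = (round(xmin * img_width), round(ymin * img_height),\n",
    "                     round(xmax * img_width), round(ymax * img_height))\n",
    "        boxes.append((int(class_index), confidence, pixel_box))\n",
    "    return boxes\n",
    "\n",
    "\n",
    "# Hypothetical response in the documented format: one confident box, one low-score box.\n",
    "example_result = {'prediction': [[0.0, 0.9, 0.1, 0.2, 0.5, 0.6],\n",
    "                                 [1.0, 0.1, 0.0, 0.0, 1.0, 1.0]]}\n",
    "print(decode_detections(example_result, img_width=640, img_height=480))"
   ]
  },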
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clean up\n",
    "\n",
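    "Finally, let's clean up and delete this endpoint. A hosted endpoint keeps incurring charges for as long as it is running, so delete it once you no longer need it."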
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boto3.client(service_name='sagemaker').delete_endpoint(EndpointName=endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Review\n",
    "\n",
    "We covered a lot of ground in this notebook! Let's recap what we accomplished. First, we started with an unlabeled dataset (technically, the dataset was previously labeled by its authors, but we discarded the original labels for the purposes of this demonstration). Next, we created a SageMaker Ground Truth labeling job and generated new labels for all of the images in our dataset. Then we split the labeling job's output manifest into a training set and a validation set, trained a SageMaker object detection model on these Ground Truth labels, and submitted a batch job to label a held-out image from the original dataset. Finally, we created a hosted model endpoint and used it to make a live prediction for the same held-out image."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
